Merge pull request from tensorflow/master

downstream merge
tg-at-google 2020-07-22 14:40:47 -04:00 committed by GitHub
commit 3bb28df8d4
1812 changed files with 62447 additions and 25941 deletions
.bazelrc
RELEASE.md
tensorflow/
  BUILD
  api_template.__init__.py
  api_template_v1.__init__.py
  c/
  cc/saved_model/experimental/tests/
  compiler/


@ -78,7 +78,16 @@
# elinux: General Embedded Linux options shared by all flavors.
# elinux_aarch64: Embedded Linux options for aarch64 (ARM64) CPU support.
# elinux_armhf: Embedded Linux options for armhf (ARMv7) CPU support.
#
# Release build options (for all operating systems)
# release_common: Common options for all builds on all operating systems.
# release_windows_common: Common options for all builds on Windows.
# release_gpu_common: Common options for GPU builds on Linux and Windows.
# release_cpu_linux: Toolchain and CUDA options for Linux CPU builds.
# release_cpu_macos: Toolchain and CUDA options for MacOS CPU builds.
# release_gpu_linux: Toolchain and CUDA options for Linux GPU builds.
# release_cpu_windows: Toolchain and CUDA options for Windows CPU builds.
# release_gpu_windows: Toolchain and CUDA options for Windows GPU builds.
# Allow builds using libc++ as a linker library
# This is mostly for OSSFuzz, so we also pass in the flags from environment to clean build file
@ -534,3 +543,41 @@ try-import %workspace%/.tf_configure.bazelrc
# Put user-specific options in .bazelrc.user
try-import %workspace%/.bazelrc.user
# Here are bazelrc configs for release builds
build:release_common --config=opt
build:release_common --config=v2
build:release_common --action_env TF_CONFIGURE_IOS="0"
build:release_cpu_linux --config=release_common
build:release_cpu_linux --config=avx_linux
# We use the same toolchain for CPU/GPU packages.
# Did not add this to the defaults in case this changes.
build:release_cpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_cpu_macos --config=release_common
build:release_cpu_macos --config=avx_linux
build:release_gpu_common --config=release_common
build:release_gpu_common --config=cuda
build:release_gpu_common --config=tensorrt
build:release_gpu_common --action_env CUDA_TOOLKIT_PATH="/usr/local/cuda-10.1"
build:release_gpu_common --action_env=TF_CUDA_VERSION="10"
build:release_gpu_common --action_env=TF_CUDNN_VERSION="7"
build:release_gpu_common --action_env=TF_CUDA_COMPUTE_CAPABILITIES="sm_35,sm_37,sm_52,sm_60,sm_61,compute_70"
build:release_gpu_common --action_env=TENSORRT_INSTALL_PATH="/usr/local/tensorrt"
build:release_gpu_common --action_env=LD_LIBRARY_PATH="/usr/local/tensorrt/lib"
build:release_gpu_common --action_env=GCC_HOST_COMPILER_PATH="/usr/bin/gcc-5"
build:release_gpu_linux --config=release_gpu_common
build:release_gpu_linux --config=avx_linux
build:release_gpu_linux --crosstool_top=//third_party/toolchains/preconfig/ubuntu16.04/gcc7_manylinux2010-nvcc-cuda10.1:toolchain
build:release_windows_common --config=release_common
build:release_windows_common --define=no_tensorflow_py_deps=true
build:release_windows_common --announce_rc
build:release_cpu_windows --config=release_windows_common
build:release_gpu_windows --config=release_windows_common


@ -11,6 +11,9 @@
* C-API functions `TF_StringDecode`, `TF_StringEncode`, and
`TF_StringEncodedSize` are no longer relevant and have been removed; see
core/platform/ctstring.h for string access/modification in C.
* Removed `tf.distribute.Strategy.experimental_run_v2` method, which was deprecated in TF 2.2.
* `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` modules are
now hidden. These modules are not part of the TensorFlow public API.
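A minimal sketch of how these breaking changes surface in user code, assuming `tf.distribute.Strategy.run` is the intended replacement for the removed `experimental_run_v2` (the snippet is illustrative, not from this change):
```python
import tensorflow as tf

# Migration for the removed `experimental_run_v2` (assumed replacement:
# `Strategy.run`, available since TF 2.2).
strategy = tf.distribute.MirroredStrategy()

@tf.function
def step(x):
  return x * 2.0

per_replica = strategy.run(step, args=(tf.constant(1.0),))

# `tensorflow.python`, `tensorflow.core` and `tensorflow.compiler` are no
# longer exposed as attributes of the top-level module.
assert "python" not in dir(tf)
```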
## Known Caveats
@ -20,6 +23,7 @@
* <INSERT MAJOR FEATURE HERE, USING MARKDOWN SYNTAX>
* <IF RELEASE CONTAINS MULTIPLE FEATURES FROM SAME AREA, GROUP THEM TOGETHER>
* A new module named `tf.experimental.numpy` is added, which is a NumPy-compatible API for writing TF programs. This module provides class `ndarray`, which mimics the `ndarray` class in NumPy and wraps an immutable `tf.Tensor` under the hood. A subset of NumPy functions (e.g. `numpy.add`) is provided, and their interoperation with TF facilities is seamless in most cases. See tensorflow/python/ops/numpy_ops/README.md for details of what is supported and how it differs from NumPy.
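A brief usage sketch of the new module, assuming it is available under `tf.experimental.numpy` as described:
```python
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

x = tnp.asarray([[1.0, 2.0], [3.0, 4.0]])  # tnp.ndarray wrapping a tf.Tensor
y = tnp.add(x, 1.0)                        # NumPy-style op
z = tf.reduce_sum(y)                       # interoperates with regular TF ops
print(z.numpy())                           # 14.0
```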
## Bug Fixes and Other Changes
@ -27,10 +31,22 @@
* <IF A CHANGE CLOSES A GITHUB ISSUE, IT SHOULD BE DOCUMENTED HERE>
* <NOTES SHOULD BE GROUPED PER AREA>
* TF Core:
* <ADD RELEASE NOTES HERE>
* <ADD RELEASE NOTES HERE>
* `tf.types.experimental.TensorLike` is a new `Union` type that can be used as
a type annotation for variables representing a Tensor or a value that can be
converted to a Tensor by `tf.convert_to_tensor` (see the sketch below).
* Calling ops with Python constants or NumPy values is now consistent with
`tf.convert_to_tensor` behavior. This prevents operations like `tf.reshape`
from truncating inputs, for example from int64 to int32.
* Added `tf.sparse.map_values` to apply a function to the `.values` of `SparseTensor` arguments.
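An illustrative sketch of the `TensorLike` annotation and `tf.sparse.map_values` additions above; the `scale` helper is hypothetical:
```python
import tensorflow as tf

def scale(x: tf.types.experimental.TensorLike, factor: float) -> tf.Tensor:
  # Accepts a tf.Tensor, Python scalars/lists, or NumPy arrays.
  return tf.convert_to_tensor(x) * factor

scale([1.0, 2.0, 3.0], 2.0)  # -> [2.0, 4.0, 6.0]

st = tf.sparse.from_dense([[1, 0], [0, 2]])
doubled = tf.sparse.map_values(tf.multiply, st, 2)  # only `.values` change
print(tf.sparse.to_dense(doubled))  # [[2 0] [0 4]]
```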
* `tf.data`:
* Added optional `exclude_cols` parameter to `CsvDataset`. This parameter is
the complement of `select_cols`; at most one of these should be specified
(see the sketch below).
* We have implemented an optimization which reorders data-discarding
transformations such as `take` and `shard` to happen earlier in the
dataset when it is safe to do so. The optimization can be disabled via
the `experimental_optimization.reorder_data_discarding_ops` dataset
option.
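A sketch combining both `tf.data` changes above; the file name and column layout are hypothetical:
```python
import tensorflow as tf

# Hypothetical CSV with four columns; skip column 2 instead of listing the
# columns to keep.
ds = tf.data.experimental.CsvDataset(
    "logs.csv",
    record_defaults=[tf.int64, tf.string, tf.float32],
    exclude_cols=[2])

# Disable the data-discarding reordering optimization described above.
options = tf.data.Options()
options.experimental_optimization.reorder_data_discarding_ops = False
ds = ds.with_options(options).shard(num_shards=4, index=0).take(10)
```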
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
@ -38,7 +54,8 @@
* `tf.function`/AutoGraph:
* <ADD RELEASE NOTES HERE>
* `tf.lite`:
* <ADD RELEASE NOTES HERE>
* Better support for ops with high-dimensional broadcasting inputs by adding
`BroadcastTo` ops when necessary.
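A hypothetical conversion that exercises high-dimensional broadcasting, where the converter may now insert `BroadcastTo` ops instead of failing:
```python
import tensorflow as tf

class BroadcastAdd(tf.Module):

  @tf.function(input_signature=[
      tf.TensorSpec([1, 1, 2, 3], tf.float32),
      tf.TensorSpec([4, 5, 2, 3], tf.float32),
  ])
  def __call__(self, a, b):
    return a + b  # 4-D broadcasting between the two inputs

m = BroadcastAdd()
concrete = m.__call__.get_concrete_function()
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete])
tflite_model = converter.convert()
```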
* `tf.random`:
* <ADD RELEASE NOTES HERE>
* Math and Linear Algebra:
@ -50,9 +67,9 @@
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>
* Other:
* We have replaced uses of "whitelist" with "allowlist" where possible.
Please see https://developers.google.com/style/word-list#blacklist for more
context.
* We have replaced uses of "whitelist" and "blacklist" with "allowlist"
and "denylist" where possible. Please see
https://developers.google.com/style/word-list#blacklist for more context.
* <ADD RELEASE NOTES HERE>
## Thanks to our Contributors


@ -532,16 +532,14 @@ selects.config_setting_group(
package_group(
name = "internal",
packages = [
# To pass open source testing in the pip Kokoros.
"//bazel_pip/tensorflow/...",
"//learning/brain/swift/x10/...",
"//perftools/accelerators/xprof/api/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
"//tensorflow/...",
"//tensorflow_estimator/python/estimator/...",
"//tensorflow_models/official/...",
"//third_party/py/autograph/...",
"//third_party/swift/tensorflow/x10/...",
"//third_party/swift/tensorflow_apis/...",
],
)


@ -158,4 +158,23 @@ if hasattr(_current_module, 'keras'):
setattr(_current_module, "initializers", initializers)
# pylint: enable=undefined-variable
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For example, this file is originally placed under tensorflow/_api/v1, which
# does not have the 'python' and 'core' directories; it is then copied to
# tensorflow/, which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER


@ -156,4 +156,25 @@ if _running_from_pip_package():
if _fi.file_exists(_plugin_dir):
_ll.load_library(_plugin_dir)
# Delete modules that should be hidden from dir().
# Don't fail if these modules are not available.
# For example, this file is originally placed under tensorflow/_api/v1, which
# does not have the 'python' and 'core' directories; it is then copied to
# tensorflow/, which does have these two directories.
# pylint: disable=undefined-variable
try:
del python
except NameError:
pass
try:
del core
except NameError:
pass
try:
del compiler
except NameError:
pass
# __all__ PLACEHOLDER


@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.pb.h"
@ -525,12 +526,12 @@ tensorflow::Status EnableCollectiveOps(const tensorflow::ServerDef& server_def,
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
std::move(new_server), grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
grpc_server->worker_env()->collective_executor_mgr.get()));
} else {
LOG_AND_RETURN_IF_ERROR(grpc_server->UpdateServerDef(server_def));
LOG_AND_RETURN_IF_ERROR(context->StoreCollectiveOpsServer(
/*new_server=*/nullptr, grpc_server->worker_env()->device_mgr,
grpc_server->worker_env()->collective_executor_mgr));
grpc_server->worker_env()->collective_executor_mgr.get()));
}
return tensorflow::Status::OK();
#undef LOG_AND_RETURN_IF_ERROR
@ -551,6 +552,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
status->status = EnableCollectiveOps(server_def, ctx);
}
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status) {
tensorflow::EagerContext* context =
tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
auto collective_executor_handle = context->GetCollectiveExecutorHandle();
collective_executor_handle->get()->StartAbort(status->status);
}
TF_ShapeAndTypeList* TF_NewShapeAndTypeList(int num_items) {
TF_ShapeAndTypeList* result = new TF_ShapeAndTypeList;
result->num_items = num_items;


@ -230,6 +230,14 @@ TF_CAPI_EXPORT extern void TFE_EnableCollectiveOps(TFE_Context* ctx,
size_t proto_len,
TF_Status* status);
// Aborts all ongoing collectives with the specified status. After the abort,
// subsequent collective ops will fail immediately with this status.
//
// This is intended to be used when a peer failure is detected. There is
// currently no way to reset the collectives other than restarting the program.
TF_CAPI_EXPORT extern void TFE_AbortCollectiveOps(TFE_Context* ctx,
TF_Status* status);
// Information about the shape of a Tensor and its type.
struct TF_ShapeAndType {
// Number of dimensions. -1 indicates unknown rank.


@ -240,6 +240,8 @@ tf_cuda_cc_test(
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
@ -308,6 +310,8 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/util:abstract_stack_trace",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
)
@ -514,7 +518,6 @@ tf_cuda_cc_test(
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",
@ -581,7 +584,6 @@ tf_cuda_cc_test(
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
"noasan", # leaks gRPC server instances
],
deps = [
":c_api",


@ -94,7 +94,6 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/public/version.h"
using tensorflow::int64;
using tensorflow::string;
namespace {
@ -968,7 +967,7 @@ int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
return -1;
}
int64 num_elements = -1;
tensorflow::int64 num_elements = -1;
status->status = tensorflow::unwrap(h)->NumElements(&num_elements);
return num_elements;
}
@ -980,7 +979,7 @@ int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
return -1;
}
int64 dim = -1;
tensorflow::int64 dim = -1;
status->status = tensorflow::unwrap(h)->Dim(dim_index, &dim);
return dim;
}


@ -174,9 +174,9 @@ void TestFunctionWithPackedInput(const bool remote) {
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
// Create one variable per task.
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task0_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task1_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task2_name);
TFE_TensorHandle* h0 = TestVariable(ctx, 1.0, task1_name);
TFE_TensorHandle* h1 = TestVariable(ctx, 2.0, task2_name);
TFE_TensorHandle* h2 = TestVariable(ctx, 3.0, task0_name);
// Add a sync point in order to make sure that variables have been initialized
// before the function execution starts.
@ -185,6 +185,9 @@ void TestFunctionWithPackedInput(const bool remote) {
VarIsInitialized(ctx, h2);
// Pack 3 variable handles into one TFE_TensorHandle.
// When remote is false, the function device is placed on task0 and the handle
// types are REMOTE, REMOTE, LOCAL on task0. When remote is true, the function
// device is placed on task1 and the handle types are LOCAL, REMOTE, LOCAL on
// task1.
int num_replicas = 3;
std::vector<TFE_TensorHandle*> handles = {h0, h1, h2};
TFE_TensorHandle* packed_handle =


@ -88,6 +88,20 @@ TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx) {
return th;
}
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx) {
constexpr int64_t dims[] = {100, 100};
constexpr int num_elements = dims[0] * dims[1];


@ -34,6 +34,12 @@ TFE_TensorHandle* DoubleTestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2x2 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle(TFE_Context* ctx);
// Return a tensor handle containing a 2D matrix with the given data and
// dimensions.
TFE_TensorHandle* TestMatrixTensorHandleWithInput(TFE_Context* ctx,
float data[], int64_t dims[],
int num_dims);
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);


@ -33,6 +33,7 @@ limitations under the License.
using tensorflow::dyn_cast;
using tensorflow::string;
using tensorflow::gtl::ArraySlice;
namespace tensorflow {
namespace tracing {
@ -138,20 +139,23 @@ class GraphOperation : public TracingOperation {
Status SetAttrString(const char* attr_name, const char* data,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrString has not been implemented yet.");
tensorflow::StringPiece s(data, length);
op_->node_builder.Attr(attr_name, s);
return Status::OK();
}
Status SetAttrInt(const char* attr_name, int64_t value) override {
return tensorflow::errors::Unimplemented(
"SetAttrInt has not been implemented yet.");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(attr_name, static_cast<tensorflow::int64>(value));
return Status::OK();
}
Status SetAttrFloat(const char* attr_name, float value) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloat has not been implemented yet.");
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrBool(const char* attr_name, bool value) override {
return tensorflow::errors::Unimplemented(
"SetAttrBool has not been implemented yet.");
op_->node_builder.Attr(attr_name, value);
return Status::OK();
}
Status SetAttrType(const char* const attr_name, DataType value) override {
if (!op_) {
@ -164,8 +168,15 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) override {
return tensorflow::errors::Unimplemented(
"SetAttrShape has not been implemented yet.");
PartialTensorShape shape;
if (num_dims >= 0) {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shape = PartialTensorShape(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims), num_dims));
}
op_->node_builder.Attr(attr_name, shape);
return Status::OK();
}
Status SetAttrFunction(const char* attr_name,
const AbstractOperation* value) override {
@ -174,8 +185,10 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) override {
return tensorflow::errors::Unimplemented(
"SetAttrFunctionName has not been implemented yet.");
tensorflow::NameAttrList func_name;
func_name.set_name(string(value, value + length));
op_->node_builder.Attr(attr_name, func_name);
return Status::OK();
}
Status SetAttrTensor(const char* attr_name,
AbstractTensorInterface* tensor) override {
@ -184,33 +197,71 @@ class GraphOperation : public TracingOperation {
}
Status SetAttrStringList(const char* attr_name, const void* const* values,
const size_t* lengths, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrStringList has not been implemented yet.");
if (strcmp(attr_name, tensorflow::kColocationAttrName) == 0) {
op_->colocation_constraints.clear();
for (int i = 0; i < num_values; ++i) {
op_->colocation_constraints.emplace(static_cast<const char*>(values[i]),
lengths[i]);
}
} else {
std::vector<tensorflow::StringPiece> v;
v.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
v.emplace_back(static_cast<const char*>(values[i]), lengths[i]);
}
op_->node_builder.Attr(attr_name, v);
}
return Status::OK();
}
Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrFloatList has not been implemented yet.");
op_->node_builder.Attr(attr_name,
ArraySlice<const float>(values, num_values));
return Status::OK();
}
Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrIntList has not been implemented yet.");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
op_->node_builder.Attr(
attr_name,
ArraySlice<const tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(values), num_values));
return Status::OK();
}
Status SetAttrTypeList(const char* attr_name, const DataType* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrTypeList has not been implemented yet.");
op_->node_builder.Attr(attr_name,
ArraySlice<const DataType>(values, num_values));
return Status::OK();
}
Status SetAttrBoolList(const char* attr_name, const unsigned char* values,
int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrBoolList has not been implemented yet.");
std::unique_ptr<bool[]> b(new bool[num_values]);
for (int i = 0; i < num_values; ++i) {
b[i] = values[i];
}
op_->node_builder.Attr(attr_name,
ArraySlice<const bool>(b.get(), num_values));
return Status::OK();
}
Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) override {
return tensorflow::errors::Unimplemented(
"SetAttrShapeList has not been implemented yet.");
std::vector<PartialTensorShape> shapes;
shapes.reserve(num_values);
for (int i = 0; i < num_values; ++i) {
if (num_dims[i] < 0) {
shapes.emplace_back();
} else {
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
shapes.emplace_back(ArraySlice<tensorflow::int64>(
reinterpret_cast<const tensorflow::int64*>(dims[i]), num_dims[i]));
}
}
op_->node_builder.Attr(attr_name, shapes);
return Status::OK();
}
Status SetAttrFunctionList(
const char* attr_name,


@ -92,9 +92,255 @@ TEST_P(UnifiedCAPI, TestBasicEager) {
TF_DeleteExecutionContext(ctx);
}
// MatMul Test
TEST_P(UnifiedCAPI, TestBasicEagerMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatMul example:
[[0,0], * [[0,0], = [[0,0],
[0,0]] [0,0]] [0,0]]
*/
// Build an abstract input tensor.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals[] = {0.0f, 0.0f, 0.0f, 0.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t =
TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);
TF_AbstractTensor* at = TF_CreateAbstractTensorFromEagerTensor(
t, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into an array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
int data_len = 4; // length of result_data
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], 0);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
// MatMul Test 2
TEST_P(UnifiedCAPI, TestBasicEagerMatMul2) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatMul example with abstract tensors:
[[1,2], * [[5,6], = [[19,22],
[3,4]] [7,8]] [43,50]]
*/
// Build 1st Matrix.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t1 =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
t1, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build 2nd Matrix.
float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
TFE_TensorHandle* t2 =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
t2, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at1, at2};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at1);
TF_DeleteAbstractTensor(at2);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
// Build expected result & verify.
float e_vals[] = {19.0f, 22.0f, 43.0f, 50.0f};
int data_len = 4; // length of e_vals
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], e_vals[i]);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
// MatAdd
TEST_P(UnifiedCAPI, TestBasicEagerMatAdd) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* ctx = TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Want to test simple MatAdd example with abstract tensors:
[[1,2] , + [[5,6], = [[6,8],
[3,4] ] [7,8] ] [10,12]]
*/
// Build 1st Matrix.
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
float vals1[] = {1.0f, 2.0f, 3.0f, 4.0f};
TFE_Context* eager_ctx = TF_ExecutionContextGetTFEContext(ctx, status.get());
TFE_TensorHandle* t1 =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
TF_AbstractTensor* at1 = TF_CreateAbstractTensorFromEagerTensor(
t1, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build 2nd Matrix.
float vals2[] = {5.0f, 6.0f, 7.0f, 8.0f};
TFE_TensorHandle* t2 =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
TF_AbstractTensor* at2 = TF_CreateAbstractTensorFromEagerTensor(
t2, status.get()); // get abstract tensor
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp(ctx);
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at1, at2};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at1);
TF_DeleteAbstractTensor(at2);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
// Copy Tensor data into array.
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(result_tensor),
TF_TensorByteSize(result_tensor));
// Build expected result & verify.
float e_vals[] = {6.0f, 8.0f, 10.0f, 12.0f};
int data_len = 4; // length of e_vals
for (int i = 0; i < data_len; i++) {
EXPECT_EQ(result_data[i], e_vals[i]);
}
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TF_DeleteOutputList(o);
TF_DeleteExecutionContext(ctx);
}
TEST_P(UnifiedCAPI, TestBasicGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Start a new function / execution context.
string fn_name = "double";
TF_ExecutionContext* graph_ctx =
@ -142,6 +388,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
@ -180,6 +427,111 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
TF_DeleteExecutionContext(eager_execution_ctx);
}
// Graph Tracing for MatMul
TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Start a new function / execution context.
string fn_name = "matrix_multiply";
TF_ExecutionContext* graph_ctx =
TF_CreateFunction(fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* matmul_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(matmul_op, "MatMul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(matmul_op, "my_matmul", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {placeholder_t, placeholder_t};
TF_OutputList* mm_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(matmul_op, 2, inputs, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(matmul_op);
TF_AbstractFunction* func =
TF_FinalizeFunction(graph_ctx, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
/* Now that the graph is built, test graph implementation on matmul example:
[[1,1] , * [[1,1] , = [[2,2],
[1,1]] [1,1]] [2,2]]
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx, status.get());
float vals[] = {1.0f, 1.0f, 1.0f, 1.0f};
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
TFE_TensorHandle* input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals, dims, num_dims);
TF_AbstractTensor* input_t =
TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, mm_outputs, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(mm_outputs));
TF_AbstractTensor* final_result = TF_OutputListGet(mm_outputs, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));
int data_len = 4;
for (int i = 0; i < data_len; i++) {
ASSERT_EQ(result_data[i], 2.0f);
}
TF_DeleteAbstractTensor(final_result);
TF_DeleteOutputList(mm_outputs);
TF_DeleteAbstractTensor(placeholder_t);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TF_DeleteTensor(f_t);
TF_DeleteAbstractFunction(func);
TF_DeleteExecutionContext(eager_execution_ctx);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -336,6 +688,217 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_DeleteAbstractFunction(func);
}
TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Status* s = status.get();
// Start a new function / execution context.
string fn_name = "two_adds_and_matmul";
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
TF_AbstractTensor* add_output1;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add1", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg0, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output1 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// Same with a second "Add" computing `arg1 + arg1`.
TF_AbstractTensor* add_output2;
{
// Build an abstract operation, inputs and output.
auto* add_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(add_op, "Add", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(add_op, "my_add2", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {arg1, arg1};
TF_OutputList* add_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(add_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(add_op);
// Extract the resulting tensor.
add_output2 = TF_OutputListGet(add_outputs, 0);
TF_DeleteOutputList(add_outputs);
}
// 3rd Output will be Matrix Multiplication of add_output1 and add_output2
TF_AbstractTensor* mm_output;
{
// Build an abstract operation, inputs and output.
auto* mm_op = TF_NewAbstractOp(graph_ctx);
TF_AbstractOpSetOpType(mm_op, "MatMul", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractOpSetOpName(mm_op, "mm", s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_AbstractTensor* inputs[2] = {add_output1, add_output2};
TF_OutputList* mm_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(mm_outputs, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Trace the operation now (create a node in the graph).
TF_ExecuteOperation(mm_op, 2, inputs, mm_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(mm_op);
// Extract the resulting tensor.
mm_output = TF_OutputListGet(mm_outputs, 0);
TF_DeleteOutputList(mm_outputs);
}
// Finalize the function by providing the returned values.
TF_AbstractFunction* func;
{
// We want to return the outputs of both add operations and the MatMul
// operation, so create a new list and populate it.
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListPushBack(func_outputs, add_output1, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, add_output2, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_OutputListPushBack(func_outputs, mm_output, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
func = TF_FinalizeFunction(graph_ctx, func_outputs, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteOutputList(func_outputs);
}
/**
* We traced so far this function:
*
* def two_adds_and_mm(A, B):
* my_add1 = A + B
* my_add2 = B + B
* mm = tf.MatMul(my_add1,my_add2)
* return my_add1, my_add2, mm
*
* Now we will execute this function with an eager context:
*
* A =[[0, 1],[1, 0]]
* B =[[1, 0],[0, 1]]
*
* output1, output2, output3 = two_adds_and_mm(A, B)
*
* We expect outputs:
*
* output1 = [[1, 1],[1, 1]]
* output2 = [[2, 0],[0, 2]]
* output3 = [[2, 2],[2, 2]]
*
*/
// Build eager context.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TF_ExecutionContext* eager_execution_ctx =
TF_NewEagerExecutionContext(opts, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TFE_DeleteContextOptions(opts);
TF_ExecutionContextRegisterFunction(eager_execution_ctx, func, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build the abstract op to run the function.
TF_AbstractOp* fn_op = TF_NewAbstractOp(eager_execution_ctx);
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Build two abstract input tensors as function arguments.
std::vector<TF_AbstractTensor*> func_args;
{
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(eager_execution_ctx, s);
// 1st Arg
float vals1[] = {0.0f, 1.0f, 1.0f, 0.0f};
int64_t dims[] = {2, 2}; // Matrices will be 2 x 2
int num_dims = sizeof(dims) / sizeof(dims[0]);
TFE_TensorHandle* input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals1, dims, num_dims);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// 2nd Arg
float vals2[] = {1.0f, 0.0f, 0.0f, 1.0f};
input_eager =
TestMatrixTensorHandleWithInput(eager_ctx, vals2, dims, num_dims);
func_args.push_back(TF_CreateAbstractTensorFromEagerTensor(input_eager, s));
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
}
TF_OutputList* func_outputs = TF_NewOutputList();
TF_OutputListSetNumOutputs(func_outputs, 3, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_ExecuteOperation(fn_op, func_args.size(), func_args.data(), func_outputs,
s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_DeleteAbstractOp(fn_op);
for (TF_AbstractTensor* t : func_args) TF_DeleteAbstractTensor(t);
ASSERT_EQ(3, TF_OutputListNumOutputs(func_outputs));
float expected_outputs[3][4] = {{1.0f, 1.0f, 1.0f, 1.0f},
{2.0f, 0.0f, 0.0f, 2.0f},
{2.0f, 2.0f, 2.0f, 2.0f}};
float result_data[4];
for (int idx = 0; idx < 3; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TFE_TensorHandle* handle = TF_AbstractTensorGetEagerTensor(result, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
TF_Tensor* f_t = TFE_TensorHandleResolve(handle, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
memcpy(&result_data[0], TF_TensorData(f_t), TF_TensorByteSize(f_t));
// Verify results for each output
for (int j = 0; j < 4; j++) {
ASSERT_EQ(result_data[j], expected_outputs[idx][j]);
}
TF_DeleteTensor(f_t);
}
// Free memory associated with add and MatMul outputs
for (int idx = 0; idx < 3; ++idx) {
TF_AbstractTensor* result = TF_OutputListGet(func_outputs, idx);
TF_DeleteAbstractTensor(result);
}
TF_DeleteOutputList(func_outputs);
TF_DeleteExecutionContext(eager_execution_ctx);
TF_DeleteAbstractFunction(func);
}
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);


@ -175,7 +175,8 @@ Status TapeVSpace::CallBackwardFunction(
gtl::ArraySlice<AbstractTensorHandle*> output_gradients,
std::vector<AbstractTensorHandle*>* result) const {
if (backward_function == nullptr) return Status::OK();
return backward_function->Compute(output_gradients, result);
Context ctx = {ctx_};
return backward_function->Compute(&ctx, output_gradients, result);
}
// Looks up the ID of a Gradient.


@ -31,7 +31,8 @@ namespace gradients {
//
// class AddGradientFunction : public GradientFunction {
// public:
// Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
// Status Compute(Context* ctx,
// absl::Span<AbstractTensorHandle* const> grad_inputs,
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
// grad_outputs->resize(2);
// (*grad_outputs)[0] = grad_inputs[0];
@ -50,11 +51,16 @@ namespace gradients {
// Status RegisterGradients(GradientRegistry* registry) {
// return registry->Register("Add", AddRegisterer);
// }
struct Context {
public:
AbstractContext* ctx;
};
class GradientFunction {
public:
// TODO(srbs): Figure out how we support CompositeTensors, e.g. IndexedSlices,
// in `grad_inputs`.
virtual Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
virtual Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) = 0;
virtual ~GradientFunction() {}
};


@ -23,6 +23,8 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@ -42,55 +44,10 @@ class CppGradients
}
};
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
return Status::OK();
}
// =================== Register gradients for Add ============================
class AddGradientFunction : public GradientFunction {
public:
explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
absl::MakeSpan(identity_outputs), "Id1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
private:
AbstractContext* ctx_;
};
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction(op.ctx);
}
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
}
// =================== End gradient registrations ============================
// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,


@ -17,6 +17,7 @@ limitations under the License.
#include <memory>
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
@ -26,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/abstract_stack_trace.h"
struct TFE_Op;
@ -44,6 +46,12 @@ class ImmediateExecutionOperation : public AbstractOperation {
// Experimental
virtual Status SetUseXla(bool enable) = 0;
// Set stack trace to be used for potential async error reporting.
virtual void SetStackTrace(AbstractStackTrace stack_trace) = 0;
// Returns the stack trace set by `SetStackTrace` if exists.
virtual absl::optional<AbstractStackTrace> GetStackTrace() = 0;
// For LLVM style RTTI.
static bool classof(const AbstractOperation* ptr) {
return ptr->getKind() == kEager || ptr->getKind() == kTfrt;


@ -78,6 +78,11 @@ typedef struct TF_Filesystem {
void* plugin_filesystem;
} TF_Filesystem;
typedef struct TF_TransactionToken {
void* token;
TF_Filesystem* owner;
} TF_TransactionToken;
/// SECTION 2. Function tables for functionality provided by plugins
/// ----------------------------------------------------------------------------
///
@ -679,6 +684,133 @@ typedef struct TF_FilesystemOps {
///
/// DEFAULT IMPLEMENTATION: No op.
void (*flush_caches)(const TF_Filesystem* filesystem);
/// Starts a new transaction.
///
/// An opaque transaction token is returned in `token`. Ownership of the token
/// stays with the filesystem. The token will be freed in the `end_transaction`
/// call, and any access to the token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if the transaction was successfully started.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if multiple transactions
/// are not supported.
/// * Might use any other error value for `status` to signal other errors.
int (*start_transaction)(const TF_Filesystem* filesystem,
TF_TransactionToken** token, TF_Status* status);
/// Ends the transaction and frees the `token`. Any access to the token after
/// that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if the transaction was successfully finalized.
/// * Must set `status` to `TF_NOT_FOUND` if the token is invalid or not found.
/// * Might use any other error value for `status` to signal other errors.
int (*end_transaction)(const TF_Filesystem* filesystem,
TF_TransactionToken* token, TF_Status* status);
/// Adds the file/directory at `path` to the transaction in `token`. It is
/// valid to add a path that does not exist yet to a transaction.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if the path was added to the transaction
/// successfully.
/// * Must set `status` to `TF_NOT_FOUND` if `token` is invalid.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if the file/directory is in
/// another transaction and multiple transactions are not supported.
/// * Might use any other error value for `status` to signal other errors.
int (*add_to_transaction)(const TF_Filesystem* filesystem, const char* path,
TF_TransactionToken* token, TF_Status* status);
/// Returns the transaction token for the file/directory at `path`. Note that
/// the path may not exist yet but might still be part of a transaction.
///
/// The transaction token is returned in `token`. Ownership of the token stays
/// with the filesystem. The token will be freed in the `end_transaction` call,
/// and any access to the token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if a transaction for `path` is found.
/// * Must set `status` to `TF_NOT_FOUND` if `path` is not part of any
/// transaction.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if `path` is
/// not in this filesystem.
/// * Might use any other error value for `status` to signal other errors.
int (*get_transaction_for_path)(const TF_Filesystem* filesystem,
const char* path, TF_TransactionToken** token,
TF_Status* status);
/// Returns the transaction token for `path` if it is part of a transaction;
/// otherwise starts a new transaction and adds `path` to that transaction.
///
/// The transaction token is returned in `token`. Ownership of the token stays
/// with the filesystem. The token will be freed in the `end_transaction` call,
/// and any access to the token after that is invalid.
///
/// In case of error, plugins must set `status` to a value different than
/// `TF_OK`, free memory allocated for `token` and return -1.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// Plugins:
/// * Must set `status` to `TF_OK` if a transaction was found or successfully
/// started.
/// * Must set `status` to `TF_NOT_FOUND` if `path` does not point to this
/// filesystem.
/// * Must set `status` to `TF_FAILED_PRECONDITION` if file/directory is
/// not in any transaction and multiple transactions are not supported.
/// * Might use any other error value for `status` to signal other errors.
int (*get_or_start_transaction_for_path)(const TF_Filesystem* filesystem,
const char* path,
TF_TransactionToken** token,
TF_Status* status);
/// Decodes the transaction token in `token` into a human-readable format for
/// debugging.
///
/// A new `char*` buffer must be allocated by this method. Core TensorFlow
/// manages the lifetime of the buffer after the call. Thus, all callers of
/// this method must take ownership of the returned pointer.
///
/// Plugins must not return `nullptr`. Returning empty strings is allowed.
///
/// The allocation and freeing of memory must happen via the functions sent to
/// core TensorFlow upon registration (see the `TF_FilesystemPluginInfo`
/// structure in Section 4).
///
/// DEFAULT IMPLEMENTATION: Dump token and owner address.
char* (*decode_transaction_token)(const TF_Filesystem* filesystem,
const TF_TransactionToken* token);
} TF_FilesystemOps;
// LINT.ThenChange(:filesystem_ops_version)


@ -35,7 +35,8 @@ using UniquePtrTo_TF_Status =
::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
Status ModularFileSystem::NewRandomAccessFile(
const std::string& fname, std::unique_ptr<RandomAccessFile>* result) {
const std::string& fname,
std::unique_ptr<RandomAccessFile>* result /*, TransactionToken* token */) {
if (ops_->new_random_access_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -54,7 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
}
Status ModularFileSystem::NewWritableFile(
const std::string& fname, std::unique_ptr<WritableFile>* result) {
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
if (ops_->new_writable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewWritableFile()"));
@ -73,7 +75,8 @@ Status ModularFileSystem::NewWritableFile(
}
Status ModularFileSystem::NewAppendableFile(
const std::string& fname, std::unique_ptr<WritableFile>* result) {
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
if (ops_->new_appendable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -92,7 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
}
Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>* result) {
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token */) {
if (ops_->new_read_only_memory_region_from_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname,
@ -112,7 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::FileExists(const std::string& fname) {
Status ModularFileSystem::FileExists(
const std::string& fname /*, TransactionToken* token */) {
if (ops_->path_exists == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support FileExists()"));
@ -124,8 +129,9 @@ Status ModularFileSystem::FileExists(const std::string& fname) {
return StatusFromTF_Status(plugin_status.get());
}
bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
std::vector<Status>* status) {
bool ModularFileSystem::FilesExist(
const std::vector<std::string>& files,
std::vector<Status>* status /*, TransactionToken* token */) {
if (ops_->paths_exist == nullptr)
return FileSystem::FilesExist(files, status);
@ -156,8 +162,9 @@ bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
return result;
}
Status ModularFileSystem::GetChildren(const std::string& dir,
std::vector<std::string>* result) {
Status ModularFileSystem::GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token */) {
if (ops_->get_children == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dir, " does not support GetChildren()"));
@ -181,8 +188,9 @@ Status ModularFileSystem::GetChildren(const std::string& dir,
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
std::vector<std::string>* result) {
Status ModularFileSystem::GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>* result /*, TransactionToken* token */) {
if (ops_->get_matching_paths == nullptr)
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
@ -203,7 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteFile(const std::string& fname) {
Status ModularFileSystem::DeleteFile(
const std::string& fname /*, TransactionToken* token */) {
if (ops_->delete_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support DeleteFile()"));
@ -215,9 +224,9 @@ Status ModularFileSystem::DeleteFile(const std::string& fname) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
int64* undeleted_files,
int64* undeleted_dirs) {
Status ModularFileSystem::DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token */) {
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
return errors::FailedPrecondition(
"DeleteRecursively must not be called with `undeleted_files` or "
@ -238,7 +247,8 @@ Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteDir(const std::string& dirname) {
Status ModularFileSystem::DeleteDir(
const std::string& dirname /*, TransactionToken* token */) {
if (ops_->delete_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support DeleteDir()"));
@ -250,7 +260,8 @@ Status ModularFileSystem::DeleteDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
Status ModularFileSystem::RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token */) {
if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname);
@ -261,7 +272,8 @@ Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CreateDir(const std::string& dirname) {
Status ModularFileSystem::CreateDir(
const std::string& dirname /*, TransactionToken* token */) {
if (ops_->create_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support CreateDir()"));
@ -273,7 +285,9 @@ Status ModularFileSystem::CreateDir(const std::string& dirname) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) {
Status ModularFileSystem::Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token */) {
if (ops_->stat == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support Stat()"));
@ -296,7 +310,8 @@ Status ModularFileSystem::Stat(const std::string& fname, FileStatistics* stat) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::IsDirectory(const std::string& name) {
Status ModularFileSystem::IsDirectory(
const std::string& name /*, TransactionToken* token */) {
if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -306,8 +321,9 @@ Status ModularFileSystem::IsDirectory(const std::string& name) {
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetFileSize(const std::string& fname,
uint64* file_size) {
Status ModularFileSystem::GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token */) {
if (ops_->get_file_size == nullptr) {
FileStatistics stat;
Status status = Stat(fname, &stat);
@ -326,8 +342,9 @@ Status ModularFileSystem::GetFileSize(const std::string& fname,
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RenameFile(const std::string& src,
const std::string& target) {
Status ModularFileSystem::RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
if (ops_->rename_file == nullptr) {
Status status = CopyFile(src, target);
if (status.ok()) status = DeleteFile(src);
@ -342,8 +359,9 @@ Status ModularFileSystem::RenameFile(const std::string& src,
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CopyFile(const std::string& src,
const std::string& target) {
Status ModularFileSystem::CopyFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -354,7 +372,8 @@ Status ModularFileSystem::CopyFile(const std::string& src,
return StatusFromTF_Status(plugin_status.get());
}
std::string ModularFileSystem::TranslateName(const std::string& name) const {
std::string ModularFileSystem::TranslateName(
const std::string& name /*, TransactionToken* token */) const {
if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name);
char* p = ops_->translate_name(filesystem_.get(), name.c_str());
@ -366,7 +385,7 @@ std::string ModularFileSystem::TranslateName(const std::string& name) const {
return ret;
}
void ModularFileSystem::FlushCaches() {
void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) {
if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
}

View File

@ -61,34 +61,69 @@ class ModularFileSystem final : public FileSystem {
Status NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewWritableFile(const std::string& fname,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const std::string& fname,
std::unique_ptr<WritableFile>* result) override;
std::unique_ptr<RandomAccessFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewReadOnlyMemoryRegionFromFile(
const std::string& fname,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status FileExists(const std::string& fname) override;
std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token = nullptr */) override;
Status FileExists(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
bool FilesExist(const std::vector<std::string>& files,
std::vector<Status>* status) override;
Status GetChildren(const std::string& dir,
std::vector<std::string>* result) override;
Status GetMatchingPaths(const std::string& pattern,
std::vector<std::string>* results) override;
Status DeleteFile(const std::string& fname) override;
Status DeleteRecursively(const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs) override;
Status DeleteDir(const std::string& dirname) override;
Status RecursivelyCreateDir(const std::string& dirname) override;
Status CreateDir(const std::string& dirname) override;
Status Stat(const std::string& fname, FileStatistics* stat) override;
Status IsDirectory(const std::string& fname) override;
Status GetFileSize(const std::string& fname, uint64* file_size) override;
Status RenameFile(const std::string& src, const std::string& target) override;
Status CopyFile(const std::string& src, const std::string& target) override;
std::string TranslateName(const std::string& name) const override;
void FlushCaches() override;
std::vector<Status>*
status /*, TransactionToken* token = nullptr */) override;
Status GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token = nullptr */)
override;
Status GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>*
results /*, TransactionToken* token = nullptr */) override;
Status DeleteFile(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override;
Status DeleteDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status CreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token = nullptr */) override;
Status IsDirectory(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token = nullptr */) override;
Status RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token = nullptr */)
override;
Status CopyFile(const std::string& src,
const std::string&
target /*, TransactionToken* token = nullptr */) override;
std::string TranslateName(
const std::string& name /*, TransactionToken* token = nullptr */)
const override;
void FlushCaches(/* TransactionToken* token=nullptr */) override;
private:
std::unique_ptr<TF_Filesystem> filesystem_;

View File

@ -25,7 +25,9 @@ cc_library(
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":expiring_lru_cache",
":gcs_helper",
":ram_file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
@ -44,14 +46,6 @@ cc_library(
],
)
cc_library(
name = "file_block_cache",
hdrs = ["file_block_cache.h"],
deps = [
"//tensorflow/c:tf_status",
],
)
cc_library(
name = "cleanup",
hdrs = ["cleanup.h"],
@ -63,7 +57,6 @@ cc_library(
hdrs = ["ram_file_block_cache.h"],
deps = [
":cleanup",
":file_block_cache",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
"@com_google_absl//absl/base:core_headers",

View File

@ -1,140 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_
#include <functional>
#include <iostream>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem {
class FileBlockCache;
/// FileBlockCacheStatsInterface allows for instrumentation of the block cache.
///
/// FileBlockCacheStatsInterface and its subclasses must be safe to use from
/// multiple threads concurrently.
///
/// WARNING! This is an experimental interface that may change or go away at any
/// time.
class FileBlockCacheStatsInterface {
public:
/// Configure is called to provide instrumentation hooks.
///
/// Note: Configure can be called multiple times (e.g. if the block cache is
/// re-initialized).
virtual void Configure(const FileBlockCache* block_cache) = 0;
/// RecordBlockLoadRequest is called to record the size of a hit block.
virtual void RecordCacheHitBlockSize(size_t bytes_transferred) = 0;
/// RecordBlockLoadRequest is called to record the size of a missed block.
virtual void RecordCacheMissBlockSize(size_t bytes_transferred) = 0;
virtual ~FileBlockCacheStatsInterface() = default;
};
/// \brief A block cache of file contents, keyed by {filename, offset}.
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class FileBlockCache {
public:
/// The callback executed when a block is not found in the cache, and needs to
/// be fetched from the backing filesystem. This callback is provided when the
/// cache is constructed. The `status` should be `TF_OK` as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<void(const std::string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred, TF_Status* status)>
BlockFetcher;
virtual ~FileBlockCache() {}
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
/// method will set `status` to:
///
/// 1) The error from the remote filesystem, if the read from the remote
/// filesystem failed.
/// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem
/// succeeded,
/// but the read returned a partial block, and the LRU cache contained a
/// block at a higher offset (indicating that the partial block should have
/// been a full block).
/// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but
/// the file contents do not extend past `offset` and thus nothing was
/// placed in `out`.
/// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was
/// placed
/// in `buffer`).
///
/// Caller is responsible for allocating memory for `buffer`.
/// `buffer` will be left unchanged in case of errors.
virtual void Read(const std::string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) = 0;
// Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file did not
// exist before. If the signature changes, update the existing signature with
// the new one and remove the file from cache.
virtual bool ValidateAndUpdateFileSignature(const std::string& filename,
int64_t file_signature) = 0;
/// Remove all cached blocks for `filename`.
virtual void RemoveFile(const std::string& filename) = 0;
/// Remove all cached data.
virtual void Flush() = 0;
/// Accessors for cache parameters.
virtual size_t block_size() const = 0;
virtual size_t max_bytes() const = 0;
virtual uint64_t max_staleness() const = 0;
/// The current size (in bytes) of the cache.
virtual size_t CacheSize() const = 0;
// Returns true if the cache is enabled. If false, the BlockFetcher callback
// is always executed during Read.
virtual bool IsCacheEnabled() const = 0;
void SetStats(FileBlockCacheStatsInterface* stats) {
if (stats == nullptr) {
std::cerr
<< "Attempted to monitor a NULL stats object. This may prevent the "
"corresponding monitoring data from being exported";
return;
}
cache_stats_ = stats;
cache_stats_->Configure(this);
}
protected:
FileBlockCacheStatsInterface* cache_stats_ = nullptr; // Not owned.
};
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_FILE_BLOCK_CACHE_H_

View File

@ -28,6 +28,27 @@ limitations under the License.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
// The environment variable that overrides the block size for aligned reads from
// GCS. Specified in MB (e.g. "16" = 16 x 1024 x 1024 = 16777216 bytes).
constexpr char kBlockSize[] = "GCS_READ_CACHE_BLOCK_SIZE_MB";
constexpr size_t kDefaultBlockSize = 64 * 1024 * 1024;
// The environment variable that overrides the max size of the LRU cache of
// blocks read from GCS. Specified in MB.
constexpr char kMaxCacheSize[] = "GCS_READ_CACHE_MAX_SIZE_MB";
constexpr size_t kDefaultMaxCacheSize = 0;
// The environment variable that overrides the maximum staleness of cached file
// contents. Once any block of a file reaches this staleness, all cached blocks
// will be evicted on the next read.
constexpr char kMaxStaleness[] = "GCS_READ_CACHE_MAX_STALENESS";
constexpr uint64_t kDefaultMaxStaleness = 0;
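// The environment variable that overrides the maximum age of entries in the
// Stat cache.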
constexpr char kStatCacheMaxAge[] = "GCS_STAT_CACHE_MAX_AGE";
constexpr uint64_t kStatCacheDefaultMaxAge = 5;
// The environment variable that overrides the maximum number of entries in the
// Stat cache.
constexpr char kStatCacheMaxEntries[] = "GCS_STAT_CACHE_MAX_ENTRIES";
constexpr size_t kStatCacheDefaultMaxEntries = 1024;
// How to upload new data when Flush() is called multiple times.
// By default the entire file is reuploaded.
constexpr char kAppendMode[] = "GCS_APPEND_MODE";
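// Illustrative configuration sketch (values are hypothetical): with
//   GCS_READ_CACHE_BLOCK_SIZE_MB=16
//   GCS_READ_CACHE_MAX_SIZE_MB=256
// the block cache uses 16 MiB blocks and holds at most 256 MiB of data; both
// values must be nonzero for the cache to be enabled (see
// RamFileBlockCache::IsCacheEnabled()).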
@ -82,28 +103,16 @@ static void MaybeAppendSlash(std::string* name) {
name->push_back('/');
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct GCSFile {
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
} GCSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// TODO(vnvo2409): Adding cache.
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
auto stream = gcs_file->gcs_client->ReadObject(
gcs_file->bucket, gcs_file->object, gcs::ReadRange(offset, offset + n));
// A helper function to actually read the data from GCS.
static int64_t LoadBufferFromGCS(const std::string& path, size_t offset,
size_t buffer_size, char* buffer,
tf_gcs_filesystem::GCSFile* gcs_file,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return -1;
auto stream = gcs_file->gcs_client.ReadObject(
bucket, object, gcs::ReadRange(offset, offset + buffer_size));
TF_SetStatusFromGCSStatus(stream.status(), status);
if ((TF_GetCode(status) != TF_OK) &&
(TF_GetCode(status) != TF_OUT_OF_RANGE)) {
@ -112,16 +121,119 @@ int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
int64_t read;
if (!absl::SimpleAtoi(stream.headers().find("content-length")->second,
&read)) {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
if (read != n) {
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
// When we read a file with an offset that is bigger than the actual file
// size, GCS will return an empty header (i.e. no `content-length` header). In
// this case, we set `read` to `0` and continue.
if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
read = 0;
} else {
TF_SetStatus(status, TF_UNKNOWN, "Could not get content-length header");
return -1;
}
}
// `TF_OUT_OF_RANGE` is not considered an error, so we clear it here.
TF_SetStatus(status, TF_OK, "");
stream.read(buffer, read);
read = stream.gcount();
if (read < buffer_size) {
// Check stat cache to see if we encountered an interrupted read.
tf_gcs_filesystem::GcsFileStat stat;
if (gcs_file->stat_cache->Lookup(path, &stat)) {
if (offset + read < stat.base.length) {
TF_SetStatus(status, TF_INTERNAL,
absl::StrCat("File contents are inconsistent for file: ",
path, " @ ", offset)
.c_str());
}
}
}
return read;
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
using ReadFn =
std::function<int64_t(const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status)>;
typedef struct GCSFile {
const std::string path;
const bool is_cache_enable;
const uint64_t buffer_size;
ReadFn read_fn;
absl::Mutex buffer_mutex;
uint64_t buffer_start ABSL_GUARDED_BY(buffer_mutex);
bool buffer_end_is_past_eof ABSL_GUARDED_BY(buffer_mutex);
std::string buffer ABSL_GUARDED_BY(buffer_mutex);
GCSFile(std::string path, bool is_cache_enable, uint64_t buffer_size,
ReadFn read_fn)
: path(path),
is_cache_enable(is_cache_enable),
buffer_size(buffer_size),
read_fn(std::move(read_fn)),
buffer_mutex(),
buffer_start(0),
buffer_end_is_past_eof(false),
buffer() {}
} GCSFile;
void Cleanup(TF_RandomAccessFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
delete gcs_file;
}
// `google-cloud-cpp` is working on a feature that we may want to use.
// See https://github.com/googleapis/google-cloud-cpp/issues/4013.
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
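// If the block cache is enabled, or the request is larger than the read-ahead
// buffer, delegate directly to `read_fn`; otherwise serve the read from (and
// refill) the small in-memory read-ahead buffer below.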
if (gcs_file->is_cache_enable || n > gcs_file->buffer_size) {
return gcs_file->read_fn(gcs_file->path, offset, n, buffer, status);
} else {
absl::MutexLock l(&gcs_file->buffer_mutex);
size_t buffer_end = gcs_file->buffer_start + gcs_file->buffer.size();
size_t copy_size = 0;
if (offset < buffer_end && gcs_file->buffer_start) {
copy_size = (std::min)(n, static_cast<size_t>(buffer_end - offset));
memcpy(buffer,
gcs_file->buffer.data() + (offset - gcs_file->buffer_start),
copy_size);
}
bool consumed_buffer_to_eof =
offset + copy_size >= buffer_end && gcs_file->buffer_end_is_past_eof;
if (copy_size < n && !consumed_buffer_to_eof) {
gcs_file->buffer_start = offset + copy_size;
gcs_file->buffer.resize(gcs_file->buffer_size);
auto read_fill_buffer = gcs_file->read_fn(
gcs_file->path, gcs_file->buffer_start, gcs_file->buffer_size,
&(gcs_file->buffer[0]), status);
gcs_file->buffer_end_is_past_eof =
(TF_GetCode(status) == TF_OUT_OF_RANGE);
if (read_fill_buffer >= 0) gcs_file->buffer.resize(read_fill_buffer);
if (TF_GetCode(status) != TF_OK &&
TF_GetCode(status) != TF_OUT_OF_RANGE) {
// Empty the buffer to avoid caching bad reads.
gcs_file->buffer.resize(0);
return -1;
}
size_t remaining_copy =
(std::min)(n - copy_size, gcs_file->buffer.size());
memcpy(buffer + copy_size, gcs_file->buffer.data(), remaining_copy);
copy_size += remaining_copy;
}
if (copy_size < n) {
// Forget the end-of-file flag to allow for clients that poll on the
// same file.
gcs_file->buffer_end_is_past_eof = false;
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
return copy_size;
}
TF_SetStatus(status, TF_OK, "");
return copy_size;
}
}
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
@ -290,11 +402,52 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
// TODO(vnvo2409): Use partial response for better performance.
// TODO(vnvo2409): We could do some cleanups like `return TF_SetStatus`.
// TODO(vnvo2409): Refactor the filesystem implementation when
// https://github.com/googleapis/google-cloud-cpp/issues/4482 is done.
GCSFile::GCSFile(google::cloud::storage::Client&& gcs_client)
: gcs_client(gcs_client), block_cache_lock() {
const char* append_mode = std::getenv(kAppendMode);
compose = (append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
uint64_t value;
block_size = kDefaultBlockSize;
size_t max_bytes = kDefaultMaxCacheSize;
uint64_t max_staleness = kDefaultMaxStaleness;
// Apply the overrides for the block size (MB), max bytes (MB), and max
// staleness (seconds) if provided.
if (absl::SimpleAtoi(std::getenv(kBlockSize), &value)) {
block_size = value * 1024 * 1024;
}
if (absl::SimpleAtoi(std::getenv(kMaxCacheSize), &value)) {
max_bytes = static_cast<size_t>(value * 1024 * 1024);
}
if (absl::SimpleAtoi(std::getenv(kMaxStaleness), &value)) {
max_staleness = value;
}
file_block_cache = std::make_unique<RamFileBlockCache>(
block_size, max_bytes, max_staleness,
[this](const std::string& filename, size_t offset, size_t buffer_size,
char* buffer, TF_Status* status) {
return LoadBufferFromGCS(filename, offset, buffer_size, buffer, this,
status);
});
uint64_t stat_cache_max_age = kStatCacheDefaultMaxAge;
size_t stat_cache_max_entries = kStatCacheDefaultMaxEntries;
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxAge), &value)) {
stat_cache_max_age = value;
}
if (absl::SimpleAtoi(std::getenv(kStatCacheMaxEntries), &value)) {
stat_cache_max_entries = static_cast<size_t>(value);
}
stat_cache = std::make_unique<ExpiringLRUCache<GcsFileStat>>(
stat_cache_max_age, stat_cache_max_entries);
}
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
@ -303,12 +456,7 @@ void Init(TF_Filesystem* filesystem, TF_Status* status) {
return;
}
const char* append_mode = std::getenv(kAppendMode);
bool compose =
(append_mode != nullptr) && (!strcmp(kAppendMode, append_mode));
filesystem->plugin_filesystem =
new GCSFile({std::move(client.value()), compose});
filesystem->plugin_filesystem = new GCSFile(std::move(client.value()));
TF_SetStatus(status, TF_OK, "");
}
@ -317,6 +465,19 @@ void Cleanup(TF_Filesystem* filesystem) {
delete gcs_file;
}
static void UncachedStatForObject(const std::string& bucket,
const std::string& object, GcsFileStat* stat,
gcs::Client* gcs_client, TF_Status* status) {
auto metadata = gcs_client->GetObjectMetadata(bucket, object);
if (!metadata) return TF_SetStatusFromGCSStatus(metadata.status(), status);
stat->generation_number = metadata->generation();
stat->base.length = metadata->size();
stat->base.mtime_nsec =
metadata->time_storage_class_updated().time_since_epoch().count();
stat->base.is_directory = object.back() == '/';
return TF_SetStatus(status, TF_OK, "");
}
// TODO(vnvo2409): Implement later
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
@ -325,8 +486,46 @@ void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
bool is_cache_enabled;
{
absl::MutexLock l(&gcs_file->block_cache_lock);
is_cache_enabled = gcs_file->file_block_cache->IsCacheEnabled();
}
auto read_fn = [gcs_file, is_cache_enabled, bucket, object](
const std::string& path, uint64_t offset, size_t n,
char* buffer, TF_Status* status) -> int64_t {
int64_t read = 0;
if (is_cache_enabled) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
GcsFileStat stat;
gcs_file->stat_cache->LookupOrCompute(
path, &stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
if (TF_GetCode(status) != TF_OK) return -1;
if (!gcs_file->file_block_cache->ValidateAndUpdateFileSignature(
path, stat.generation_number)) {
std::cout
<< "File signature has been changed. Refreshing the cache. Path: "
<< path;
}
read = gcs_file->file_block_cache->Read(path, offset, n, buffer, status);
} else {
read = LoadBufferFromGCS(path, offset, n, buffer, gcs_file, status);
}
if (TF_GetCode(status) != TF_OK) return -1;
if (read < n)
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
else
TF_SetStatus(status, TF_OK, "");
return read;
};
file->plugin_file = new tf_random_access_file::GCSFile(
{std::move(bucket), std::move(object), &gcs_file->gcs_client});
std::move(path), is_cache_enabled, gcs_file->block_size, read_fn);
TF_SetStatus(status, TF_OK, "");
}

View File

@ -17,6 +17,8 @@
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/expiring_lru_cache.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
@ -45,10 +47,23 @@ uint64_t Length(const TF_ReadOnlyMemoryRegion* region);
} // namespace tf_read_only_memory_region
namespace tf_gcs_filesystem {
typedef struct GcsFileStat {
TF_FileStatistics base;
int64_t generation_number;
} GcsFileStat;
typedef struct GCSFile {
google::cloud::storage::Client gcs_client; // owned
bool compose;
absl::Mutex block_cache_lock;
std::shared_ptr<RamFileBlockCache> file_block_cache
ABSL_GUARDED_BY(block_cache_lock);
uint64_t block_size; // Reads smaller than block_size will trigger a read
// of block_size.
std::unique_ptr<ExpiringLRUCache<GcsFileStat>> stat_cache;
GCSFile(google::cloud::storage::Client&& gcs_client);
} GCSFile;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,

View File

@ -39,9 +39,6 @@ std::shared_ptr<RamFileBlockCache::Block> RamFileBlockCache::Lookup(
auto entry = block_map_.find(key);
if (entry != block_map_.end()) {
if (BlockNotStale(entry->second)) {
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheHitBlockSize(entry->second->data.size());
}
return entry->second;
} else {
// Remove the stale block and continue.
@ -136,12 +133,9 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
block->mu.Unlock(); // Release the lock while making the API call.
block->data.clear();
block->data.resize(block_size_, 0);
size_t bytes_transferred;
block_fetcher_(key.first, key.second, block_size_, block->data.data(),
&bytes_transferred, status);
if (cache_stats_ != nullptr) {
cache_stats_->RecordCacheMissBlockSize(bytes_transferred);
}
int64_t bytes_transferred;
bytes_transferred = block_fetcher_(key.first, key.second, block_size_,
block->data.data(), status);
block->mu.Lock(); // Reacquire the lock immediately afterwards
if (TF_GetCode(status) == TF_OK) {
block->data.resize(bytes_transferred, 0);
@ -171,18 +165,16 @@ void RamFileBlockCache::MaybeFetch(const Key& key,
"Control flow should never reach the end of RamFileBlockCache::Fetch.");
}
void RamFileBlockCache::Read(const std::string& filename, size_t offset,
size_t n, char* buffer, size_t* bytes_transferred,
TF_Status* status) {
*bytes_transferred = 0;
int64_t RamFileBlockCache::Read(const std::string& filename, size_t offset,
size_t n, char* buffer, TF_Status* status) {
if (n == 0) {
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return 0;
}
if (!IsCacheEnabled() || (n > max_bytes_)) {
// The cache is effectively disabled, so we pass the read through to the
// fetcher without breaking it up into blocks.
return block_fetcher_(filename, offset, n, buffer, bytes_transferred,
status);
return block_fetcher_(filename, offset, n, buffer, status);
}
// Calculate the block-aligned start and end of the read.
size_t start = block_size_ * (offset / block_size_);
@ -202,20 +194,20 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
abort();
}
MaybeFetch(key, block, status);
if (TF_GetCode(status) != TF_OK) return;
if (TF_GetCode(status) != TF_OK) return -1;
UpdateLRU(key, block, status);
if (TF_GetCode(status) != TF_OK) return;
if (TF_GetCode(status) != TF_OK) return -1;
// Copy the relevant portion of the block into the result buffer.
const auto& data = block->data;
if (offset >= pos + data.size()) {
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
*bytes_transferred = total_bytes_transferred;
std::stringstream os;
os << "EOF at offset " << offset << " in file " << filename
<< " at position " << pos << " with data size " << data.size();
return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str());
return total_bytes_transferred;
}
auto begin = data.begin();
if (offset > pos) {
@ -237,8 +229,8 @@ void RamFileBlockCache::Read(const std::string& filename, size_t offset,
break;
}
}
*bytes_transferred = total_bytes_transferred;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return total_bytes_transferred;
}
bool RamFileBlockCache::ValidateAndUpdateFileSignature(

View File

@ -28,7 +28,6 @@ limitations under the License.
#include "absl/synchronization/mutex.h"
#include "absl/synchronization/notification.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h"
#include "tensorflow/c/tf_status.h"
namespace tf_gcs_filesystem {
@ -37,16 +36,17 @@ namespace tf_gcs_filesystem {
///
/// This class should be shared by read-only random access files on a remote
/// filesystem (e.g. GCS).
class RamFileBlockCache : public FileBlockCache {
class RamFileBlockCache {
public:
/// The callback executed when a block is not found in the cache, and needs to
/// be fetched from the backing filesystem. This callback is provided when the
/// cache is constructed. The `status` should be `TF_OK` as long as the
/// read from the remote filesystem succeeded (similar to the semantics of the
/// read(2) system call).
typedef std::function<void(const std::string& filename, size_t offset,
size_t buffer_size, char* buffer,
size_t* bytes_transferred, TF_Status* status)>
/// cache is constructed. It returns the total bytes read (-1 in case of
/// errors). The `status` should be `TF_OK` as long as the read from the
/// remote filesystem succeeded (similar to the semantics of the read(2)
/// system call).
typedef std::function<int64_t(const std::string& filename, size_t offset,
size_t buffer_size, char* buffer,
TF_Status* status)>
BlockFetcher;
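// A minimal illustrative fetcher matching this signature, assuming a
// hypothetical `ReadFromRemote` helper that returns the number of bytes it
// copied into `buffer` (or a negative value on failure):
//
//   RamFileBlockCache::BlockFetcher fetcher =
//       [](const std::string& filename, size_t offset, size_t buffer_size,
//          char* buffer, TF_Status* status) -> int64_t {
//         int64_t read = ReadFromRemote(filename, offset, buffer_size, buffer);
//         if (read < 0) {
//           TF_SetStatus(status, TF_UNKNOWN, "remote read failed");
//           return -1;
//         }
//         TF_SetStatus(status, TF_OK, "");
//         return read;
//       };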
RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness,
@ -66,10 +66,10 @@ class RamFileBlockCache : public FileBlockCache {
TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this));
}
std::cout << "GCS file block cache is "
<< (IsCacheEnabled() ? "enabled" : "disabled");
<< (IsCacheEnabled() ? "enabled" : "disabled") << ".\n";
}
~RamFileBlockCache() override {
~RamFileBlockCache() {
if (pruning_thread_) {
stop_pruning_thread_.Notify();
// Destroying pruning_thread_ will block until Prune() receives the above
@ -78,8 +78,9 @@ class RamFileBlockCache : public FileBlockCache {
}
}
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. This
/// method will set `status` to:
/// Read `n` bytes from `filename` starting at `offset` into `buffer`. It
/// returns the total bytes read (-1 in case of errors). This method will set
/// `status` to:
///
/// 1) The error from the remote filesystem, if the read from the remote
/// filesystem failed.
@ -97,37 +98,34 @@ class RamFileBlockCache : public FileBlockCache {
///
/// Caller is responsible for allocating memory for `buffer`.
/// `buffer` will be left unchanged in case of errors.
void Read(const std::string& filename, size_t offset, size_t n, char* buffer,
size_t* bytes_transferred, TF_Status* status) override;
int64_t Read(const std::string& filename, size_t offset, size_t n,
char* buffer, TF_Status* status);
// Validate the given file signature with the existing file signature in the
// cache. Returns true if the signature doesn't change or the file did not
// exist before. If the signature changes, update the existing signature with
// the new one and remove the file from cache.
bool ValidateAndUpdateFileSignature(const std::string& filename,
int64_t file_signature) override
int64_t file_signature)
ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached blocks for `filename`.
void RemoveFile(const std::string& filename) override
ABSL_LOCKS_EXCLUDED(mu_);
void RemoveFile(const std::string& filename) ABSL_LOCKS_EXCLUDED(mu_);
/// Remove all cached data.
void Flush() override ABSL_LOCKS_EXCLUDED(mu_);
void Flush() ABSL_LOCKS_EXCLUDED(mu_);
/// Accessors for cache parameters.
size_t block_size() const override { return block_size_; }
size_t max_bytes() const override { return max_bytes_; }
uint64_t max_staleness() const override { return max_staleness_; }
size_t block_size() const { return block_size_; }
size_t max_bytes() const { return max_bytes_; }
uint64_t max_staleness() const { return max_staleness_; }
/// The current size (in bytes) of the cache.
size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_);
size_t CacheSize() const ABSL_LOCKS_EXCLUDED(mu_);
// Returns true if the cache is enabled. If false, the BlockFetcher callback
// is always executed during Read.
bool IsCacheEnabled() const override {
return block_size_ > 0 && max_bytes_ > 0;
}
bool IsCacheEnabled() const { return block_size_ > 0 && max_bytes_ > 0; }
// We cannot pass a capturing lambda as a function pointer to
// `TF_StartThread`, so we have to wrap `Prune` inside a static function.

View File

@ -33,20 +33,22 @@ Status ReadCache(tf_gcs_filesystem::RamFileBlockCache* cache,
std::vector<char>* out) {
out->clear();
out->resize(n, 0);
size_t bytes_transferred = 0;
TF_Status status;
cache->Read(filename, offset, n, out->data(), &bytes_transferred, &status);
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
auto bytes_transferred =
cache->Read(filename, offset, n, out->data(), &status);
if (bytes_transferred >= 0) {
EXPECT_LE(bytes_transferred, n);
out->resize(bytes_transferred, n);
}
return status.status;
}
TEST(RamFileBlockCacheTest, IsCacheEnabled) {
auto fetcher = [](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
// Do nothing.
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return 0;
};
tf_gcs_filesystem::RamFileBlockCache cache1(0, 0, 0, fetcher);
tf_gcs_filesystem::RamFileBlockCache cache2(16, 0, 0, fetcher);
@ -62,12 +64,11 @@ TEST(RamFileBlockCacheTest, IsCacheEnabled) {
TEST(RamFileBlockCacheTest, ValidateAndUpdateFileSignature) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
string filename = "file";
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
@ -96,15 +97,14 @@ TEST(RamFileBlockCacheTest, PassThrough) {
int calls = 0;
auto fetcher = [&calls, want_filename, want_offset, want_n](
const string& got_filename, size_t got_offset,
size_t got_n, char* buffer, size_t* bytes_transferred,
TF_Status* status) {
size_t got_n, char* buffer, TF_Status* status) -> int64_t {
EXPECT_EQ(got_filename, want_filename);
EXPECT_EQ(got_offset, want_offset);
EXPECT_EQ(got_n, want_n);
calls++;
memset(buffer, 'x', got_n);
*bytes_transferred = got_n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return got_n;
};
// If block_size, max_bytes, or both are zero, or want_n is larger than
// max_bytes the cache is a pass-through.
@ -133,16 +133,17 @@ TEST(RamFileBlockCacheTest, BlockAlignment) {
}
// The fetcher just fetches slices of the buffer.
auto fetcher = [&buf](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
int64_t bytes_transferred;
if (offset < buf.size()) {
size_t bytes_to_copy = std::min<size_t>(buf.size() - offset, n);
memcpy(buffer, buf.data() + offset, bytes_to_copy);
*bytes_transferred = bytes_to_copy;
bytes_transferred = bytes_to_copy;
} else {
*bytes_transferred = 0;
bytes_transferred = 0;
}
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return bytes_transferred;
};
for (size_t block_size = 2; block_size <= 4; block_size++) {
// Make a cache of N-byte block size (1 block) and verify that reads of
@ -181,15 +182,14 @@ TEST(RamFileBlockCacheTest, CacheHits) {
std::set<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer,
size_t* bytes_transferred,
TF_Status* status) {
TF_Status* status) -> int64_t {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_EQ(calls.find(offset), calls.end()) << "at offset " << offset;
calls.insert(offset);
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
const uint32 block_count = 256;
tf_gcs_filesystem::RamFileBlockCache cache(
@ -215,8 +215,7 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
bool second_block = false;
auto fetcher = [block_size, file_size, &first_block, &second_block](
const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
size_t bytes_to_copy = 0;
@ -231,8 +230,8 @@ TEST(RamFileBlockCacheTest, OutOfRange) {
memset(buffer, 'x', bytes_to_copy);
second_block = true;
}
*bytes_transferred = bytes_to_copy;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return bytes_to_copy;
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher);
@ -260,14 +259,13 @@ TEST(RamFileBlockCacheTest, Inconsistent) {
const size_t block_size = 16;
// This fetcher returns OK but only fills in one byte for any offset.
auto fetcher = [block_size](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset % block_size, 0);
EXPECT_GE(n, 1);
memset(buffer, 'x', 1);
*bytes_transferred = 1;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return 1;
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, 2 * block_size, 0,
fetcher);
@ -286,8 +284,7 @@ TEST(RamFileBlockCacheTest, LRU) {
std::list<size_t> calls;
auto fetcher = [&calls, block_size](const string& filename, size_t offset,
size_t n, char* buffer,
size_t* bytes_transferred,
TF_Status* status) {
TF_Status* status) -> int64_t {
EXPECT_EQ(n, block_size);
EXPECT_FALSE(calls.empty()) << "at offset = " << offset;
if (!calls.empty()) {
@ -295,8 +292,8 @@ TEST(RamFileBlockCacheTest, LRU) {
calls.pop_front();
}
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
const uint32 block_count = 2;
tf_gcs_filesystem::RamFileBlockCache cache(
@ -335,12 +332,11 @@ TEST(RamFileBlockCacheTest, LRU) {
TEST(RamFileBlockCacheTest, MaxStaleness) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
std::vector<char> out;
std::unique_ptr<NowSecondsEnv> env(new NowSecondsEnv);
@ -380,8 +376,7 @@ TEST(RamFileBlockCacheTest, MaxStaleness) {
TEST(RamFileBlockCacheTest, RemoveFile) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
calls++;
char c = (filename == "a") ? 'a' : (filename == "b") ? 'b' : 'x';
if (offset > 0) {
@ -389,8 +384,8 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
c = toupper(c);
}
memset(buffer, c, n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
// This cache has space for 4 blocks; we'll read from two files.
const size_t n = 3;
@ -443,12 +438,11 @@ TEST(RamFileBlockCacheTest, RemoveFile) {
TEST(RamFileBlockCacheTest, Prune) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
std::vector<char> out;
// Our fake environment is initialized with the current timestamp.
@ -509,17 +503,17 @@ TEST(RamFileBlockCacheTest, ParallelReads) {
const int callers = 4;
BlockingCounter counter(callers);
auto fetcher = [&counter](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
counter.DecrementCount();
if (!counter.WaitFor(std::chrono::seconds(10))) {
// This avoids having the test time out, which is harder to debug.
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"desired concurrency not reached");
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"desired concurrency not reached");
return -1;
}
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
const int block_size = 8;
tf_gcs_filesystem::RamFileBlockCache cache(
@ -548,17 +542,16 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
Notification notification;
auto fetcher = [&num_requests, &notification, block_size](
const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
EXPECT_EQ(n, block_size);
EXPECT_EQ(offset, 0);
num_requests++;
memset(buffer, 'x', n);
*bytes_transferred = n;
notification.Notify();
// Wait for other thread to issue read.
Env::Default()->SleepForMicroseconds(100000); // 0.1 secs
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
tf_gcs_filesystem::RamFileBlockCache cache(block_size, block_size, 0,
fetcher);
@ -580,12 +573,11 @@ TEST(RamFileBlockCacheTest, CoalesceConcurrentReads) {
TEST(RamFileBlockCacheTest, Flush) {
int calls = 0;
auto fetcher = [&calls](const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred,
TF_Status* status) {
char* buffer, TF_Status* status) -> int64_t {
calls++;
memset(buffer, 'x', n);
*bytes_transferred = n;
return TF_SetStatus(status, TF_OK, "");
TF_SetStatus(status, TF_OK, "");
return n;
};
tf_gcs_filesystem::RamFileBlockCache cache(16, 32, 0, fetcher);
std::vector<char> out;

View File

@ -0,0 +1,46 @@
# Experimental s3 filesystem plugin.
load("//tensorflow:tensorflow.bzl", "get_win_copts", "tf_cc_shared_object")
package(
licenses = ["notice"], # Apache 2.0
)
# Filesystem implementation for S3 environments
tf_cc_shared_object(
name = "s3_filesystem",
framework_so = [],
linkstatic = False,
per_os_targets = 1,
visibility = ["//visibility:public"],
deps = [":s3_filesystem_impl"],
)
# The real implementation of the filesystem.
cc_library(
name = "s3_filesystem_impl",
srcs = ["s3_filesystem.cc"],
hdrs = ["s3_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
}),
deps = [
":aws_crypto",
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@aws",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
)
cc_library(
name = "aws_crypto",
srcs = ["aws_crypto.cc"],
hdrs = ["aws_crypto.h"],
deps = [
"@aws",
"@boringssl//:crypto",
],
alwayslink = 1,
)

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include <aws/core/utils/crypto/HashResult.h>
#include <aws/s3/S3Client.h>
#include <openssl/hmac.h>
#include <openssl/rand.h>
#include <openssl/sha.h>
namespace tf_s3_filesystem {
class AWSSha256HMACOpenSSLImpl : public Aws::Utils::Crypto::HMAC {
public:
AWSSha256HMACOpenSSLImpl() {}
virtual ~AWSSha256HMACOpenSSLImpl() = default;
Aws::Utils::Crypto::HashResult Calculate(
const Aws::Utils::ByteBuffer& toSign,
const Aws::Utils::ByteBuffer& secret) override {
unsigned int length = SHA256_DIGEST_LENGTH;
Aws::Utils::ByteBuffer digest(length);
memset(digest.GetUnderlyingData(), 0, length);
HMAC_CTX ctx;
HMAC_CTX_init(&ctx);
HMAC_Init_ex(&ctx, secret.GetUnderlyingData(),
static_cast<int>(secret.GetLength()), EVP_sha256(), NULL);
HMAC_Update(&ctx, toSign.GetUnderlyingData(), toSign.GetLength());
HMAC_Final(&ctx, digest.GetUnderlyingData(), &length);
HMAC_CTX_cleanup(&ctx);
return Aws::Utils::Crypto::HashResult(std::move(digest));
}
};
class AWSSha256OpenSSLImpl : public Aws::Utils::Crypto::Hash {
public:
AWSSha256OpenSSLImpl() {}
virtual ~AWSSha256OpenSSLImpl() = default;
Aws::Utils::Crypto::HashResult Calculate(const Aws::String& str) override {
SHA256_CTX sha256;
SHA256_Init(&sha256);
SHA256_Update(&sha256, str.data(), str.size());
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
SHA256_Final(hash.GetUnderlyingData(), &sha256);
return Aws::Utils::Crypto::HashResult(std::move(hash));
}
Aws::Utils::Crypto::HashResult Calculate(Aws::IStream& stream) override {
SHA256_CTX sha256;
SHA256_Init(&sha256);
auto currentPos = stream.tellg();
if (currentPos == std::streampos(std::streamoff(-1))) {
currentPos = 0;
stream.clear();
}
stream.seekg(0, stream.beg);
char streamBuffer
[Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE];
while (stream.good()) {
stream.read(streamBuffer,
Aws::Utils::Crypto::Hash::INTERNAL_HASH_STREAM_BUFFER_SIZE);
auto bytesRead = stream.gcount();
if (bytesRead > 0) {
SHA256_Update(&sha256, streamBuffer, static_cast<size_t>(bytesRead));
}
}
stream.clear();
stream.seekg(currentPos, stream.beg);
Aws::Utils::ByteBuffer hash(SHA256_DIGEST_LENGTH);
SHA256_Final(hash.GetUnderlyingData(), &sha256);
return Aws::Utils::Crypto::HashResult(std::move(hash));
}
};
class AWSSecureRandomBytesImpl : public Aws::Utils::Crypto::SecureRandomBytes {
public:
AWSSecureRandomBytesImpl() {}
virtual ~AWSSecureRandomBytesImpl() = default;
void GetBytes(unsigned char* buffer, size_t bufferSize) override {
assert(buffer);
int success = RAND_bytes(buffer, static_cast<int>(bufferSize));
if (success != 1) {
m_failure = true;
}
}
private:
bool m_failure;
};
std::shared_ptr<Aws::Utils::Crypto::Hash>
AWSSHA256Factory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256OpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::HMAC>
AWSSHA256HmacFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSha256HMACOpenSSLImpl>(AWSCryptoAllocationTag);
}
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes>
AWSSecureRandomFactory::CreateImplementation() const {
return Aws::MakeShared<AWSSecureRandomBytesImpl>(AWSCryptoAllocationTag);
}
} // namespace tf_s3_filesystem

View File

@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_
#include <aws/core/Aws.h>
#include <aws/core/utils/crypto/Factories.h>
#include <aws/core/utils/crypto/HMAC.h>
#include <aws/core/utils/crypto/Hash.h>
#include <aws/core/utils/crypto/SecureRandom.h>
namespace tf_s3_filesystem {
constexpr char AWSCryptoAllocationTag[] = "AWSCryptoAllocation";
class AWSSHA256Factory : public Aws::Utils::Crypto::HashFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::Hash> CreateImplementation()
const override;
};
class AWSSHA256HmacFactory : public Aws::Utils::Crypto::HMACFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::HMAC> CreateImplementation()
const override;
};
class AWSSecureRandomFactory : public Aws::Utils::Crypto::SecureRandomFactory {
public:
std::shared_ptr<Aws::Utils::Crypto::SecureRandomBytes> CreateImplementation()
const override;
};
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_AWS_CRYPTO_H_

View File

@ -0,0 +1,981 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/s3/s3_filesystem.h"
#include <aws/core/client/AsyncCallerContext.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/utils/stream/PreallocatedStreamBuf.h>
#include <aws/s3/model/AbortMultipartUploadRequest.h>
#include <aws/s3/model/CompleteMultipartUploadRequest.h>
#include <aws/s3/model/CompletedMultipartUpload.h>
#include <aws/s3/model/CompletedPart.h>
#include <aws/s3/model/CopyObjectRequest.h>
#include <aws/s3/model/CreateMultipartUploadRequest.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#include <aws/s3/model/UploadPartCopyRequest.h>
#include <stdlib.h>
#include <string.h>
#include "absl/strings/ascii.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/s3/aws_crypto.h"
#include "tensorflow/c/tf_status.h"
// Implementation of a filesystem for S3 environments.
// This filesystem will support `s3://` URI schemes.
constexpr char kS3FileSystemAllocationTag[] = "S3FileSystemAllocation";
constexpr char kS3ClientAllocationTag[] = "S3ClientAllocation";
constexpr int64_t kS3TimeoutMsec = 300000; // 5 min
constexpr char kExecutorTag[] = "TransferManagerExecutorAllocation";
constexpr int kExecutorPoolSize = 25;
constexpr uint64_t kS3MultiPartUploadChunkSize = 50 * 1024 * 1024; // 50 MB
constexpr uint64_t kS3MultiPartDownloadChunkSize = 50 * 1024 * 1024; // 50 MB
constexpr size_t kDownloadRetries = 3;
constexpr size_t kUploadRetries = 3;
constexpr size_t kS3ReadAppendableFileBufferSize = 1024 * 1024; // 1 MB
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
static inline void TF_SetStatusFromAWSError(
const Aws::Client::AWSError<Aws::S3::S3Errors>& error, TF_Status* status) {
switch (error.GetResponseCode()) {
case Aws::Http::HttpResponseCode::FORBIDDEN:
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"AWS Credentials have not been set properly. "
"Unable to access the specified S3 location");
break;
case Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE:
TF_SetStatus(status, TF_OUT_OF_RANGE, "Read less bytes than requested");
break;
case Aws::Http::HttpResponseCode::NOT_FOUND:
TF_SetStatus(status, TF_NOT_FOUND, error.GetMessage().c_str());
break;
default:
TF_SetStatus(
status, TF_UNKNOWN,
(error.GetExceptionName() + ": " + error.GetMessage()).c_str());
break;
}
}
static void ParseS3Path(const Aws::String& fname, bool object_empty_ok,
Aws::String* bucket, Aws::String* object,
TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "s3://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"S3 path doesn't start with 's3://'.");
return;
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"S3 path doesn't contain a bucket name.");
return;
}
*bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*object = fname.substr(bucket_end + 1);
if (object->empty() && !object_empty_ok) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"S3 path doesn't contain an object name.");
}
}
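// Illustrative only: for fname = "s3://my-bucket/path/to/key", ParseS3Path
// yields bucket = "my-bucket" and object = "path/to/key"; with
// object_empty_ok = false, "s3://my-bucket/" sets TF_INVALID_ARGUMENT.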
static Aws::Client::ClientConfiguration& GetDefaultClientConfig() {
ABSL_CONST_INIT static absl::Mutex cfg_lock(absl::kConstInit);
static bool init(false);
static Aws::Client::ClientConfiguration cfg;
absl::MutexLock l(&cfg_lock);
if (!init) {
const char* endpoint = getenv("S3_ENDPOINT");
if (endpoint) cfg.endpointOverride = Aws::String(endpoint);
const char* region = getenv("AWS_REGION");
// TODO (yongtang): `S3_REGION` should be deprecated after 2.0.
if (!region) region = getenv("S3_REGION");
if (region) {
cfg.region = Aws::String(region);
} else {
// Load config file (e.g., ~/.aws/config) only if AWS_SDK_LOAD_CONFIG
// is set with a truthy value.
const char* load_config_env = getenv("AWS_SDK_LOAD_CONFIG");
std::string load_config =
load_config_env ? absl::AsciiStrToLower(load_config_env) : "";
if (load_config == "true" || load_config == "1") {
Aws::String config_file;
// If AWS_CONFIG_FILE is set then use it, otherwise use ~/.aws/config.
const char* config_file_env = getenv("AWS_CONFIG_FILE");
if (config_file_env) {
config_file = config_file_env;
} else {
const char* home_env = getenv("HOME");
if (home_env) {
config_file = home_env;
config_file += "/.aws/config";
}
}
Aws::Config::AWSConfigFileProfileConfigLoader loader(config_file);
loader.Load();
auto profiles = loader.GetProfiles();
if (!profiles["default"].GetRegion().empty())
cfg.region = profiles["default"].GetRegion();
}
}
const char* use_https = getenv("S3_USE_HTTPS");
if (use_https) {
if (use_https[0] == '0')
cfg.scheme = Aws::Http::Scheme::HTTP;
else
cfg.scheme = Aws::Http::Scheme::HTTPS;
}
const char* verify_ssl = getenv("S3_VERIFY_SSL");
if (verify_ssl) {
if (verify_ssl[0] == '0')
cfg.verifySSL = false;
else
cfg.verifySSL = true;
}
    // If these timeouts are too low, uploads/downloads of large files may fail
    // with "Unable to connect to endpoint".
int64_t timeout;
cfg.connectTimeoutMs =
absl::SimpleAtoi(getenv("S3_CONNECT_TIMEOUT_MSEC"), &timeout)
? timeout
: kS3TimeoutMsec;
cfg.requestTimeoutMs =
absl::SimpleAtoi(getenv("S3_REQUEST_TIMEOUT_MSEC"), &timeout)
? timeout
: kS3TimeoutMsec;
const char* ca_file = getenv("S3_CA_FILE");
if (ca_file) cfg.caFile = Aws::String(ca_file);
const char* ca_path = getenv("S3_CA_PATH");
if (ca_path) cfg.caPath = Aws::String(ca_path);
init = true;
}
return cfg;
};
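// For illustration, the client configured above can be pointed at an
// S3-compatible endpoint purely through environment variables, which are read
// once on the first call, e.g.:
//   export S3_ENDPOINT=localhost:9000
//   export S3_USE_HTTPS=0
//   export S3_VERIFY_SSL=0
//   export AWS_REGION=us-east-1
//   export S3_CONNECT_TIMEOUT_MSEC=60000
// (The endpoint and region values here are placeholders, not defaults.)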
static void GetS3Client(tf_s3_filesystem::S3File* s3_file) {
absl::MutexLock l(&s3_file->initialization_lock);
if (s3_file->s3_client.get() == nullptr) {
Aws::SDKOptions options;
options.cryptoOptions.sha256Factory_create_fn = []() {
return Aws::MakeShared<tf_s3_filesystem::AWSSHA256Factory>(
tf_s3_filesystem::AWSCryptoAllocationTag);
};
options.cryptoOptions.sha256HMACFactory_create_fn = []() {
return Aws::MakeShared<tf_s3_filesystem::AWSSHA256HmacFactory>(
tf_s3_filesystem::AWSCryptoAllocationTag);
};
options.cryptoOptions.secureRandomFactory_create_fn = []() {
return Aws::MakeShared<tf_s3_filesystem::AWSSecureRandomFactory>(
tf_s3_filesystem::AWSCryptoAllocationTag);
};
Aws::InitAPI(options);
// The creation of S3Client disables virtual addressing:
// S3Client(clientConfiguration, signPayloads, useVirtualAddressing =
// true)
    // The purpose is to address the issue encountered when there is a `.`
    // in the bucket name: due to TLS hostname validation or DNS rules, the
    // bucket may not be resolved. Disabling virtual addressing works around
    // this. See GitHub issue 16397 for details.
s3_file->s3_client = Aws::MakeShared<Aws::S3::S3Client>(
kS3ClientAllocationTag, GetDefaultClientConfig(),
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, false);
}
}
static void GetExecutor(tf_s3_filesystem::S3File* s3_file) {
absl::MutexLock l(&s3_file->initialization_lock);
if (s3_file->executor.get() == nullptr) {
s3_file->executor =
Aws::MakeShared<Aws::Utils::Threading::PooledThreadExecutor>(
kExecutorTag, kExecutorPoolSize);
}
}
static void GetTransferManager(
const Aws::Transfer::TransferDirection& direction,
tf_s3_filesystem::S3File* s3_file) {
absl::MutexLock l(&s3_file->initialization_lock);
if (s3_file->transfer_managers[direction].get() == nullptr) {
GetS3Client(s3_file);
GetExecutor(s3_file);
Aws::Transfer::TransferManagerConfiguration config(s3_file->executor.get());
config.s3Client = s3_file->s3_client;
config.bufferSize = s3_file->multi_part_chunk_sizes[direction];
    // The transfer buffer heap limit must exceed pool size * multi-part chunk
    // size so that every executor thread can buffer a full chunk at once.
config.transferBufferMaxHeapSize =
(kExecutorPoolSize + 1) * s3_file->multi_part_chunk_sizes[direction];
s3_file->transfer_managers[direction] =
Aws::Transfer::TransferManager::Create(config);
}
}
static void ShutdownClient(Aws::S3::S3Client* s3_client) {
if (s3_client != nullptr) {
delete s3_client;
Aws::SDKOptions options;
Aws::ShutdownAPI(options);
}
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
// ----------------------------------------------------------------------------
namespace tf_random_access_file {
typedef struct S3File {
Aws::String bucket;
Aws::String object;
std::shared_ptr<Aws::S3::S3Client> s3_client;
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
bool use_multi_part_download;
} S3File;
// AWS streams destroy the buffer (buf) they are passed, so we create a new
// IOStream that retains the buffer, letting the calling function control its
// lifecycle.
class TFS3UnderlyingStream : public Aws::IOStream {
public:
using Base = Aws::IOStream;
TFS3UnderlyingStream(std::streambuf* buf) : Base(buf) {}
virtual ~TFS3UnderlyingStream() = default;
};
void Cleanup(TF_RandomAccessFile* file) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
delete s3_file;
}
static int64_t ReadS3Client(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
Aws::S3::Model::GetObjectRequest get_object_request;
  get_object_request.WithBucket(s3_file->bucket).WithKey(s3_file->object);
Aws::String bytes =
absl::StrCat("bytes=", offset, "-", offset + n - 1).c_str();
get_object_request.SetRange(bytes);
get_object_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
auto get_object_outcome = s3_file->s3_client->GetObject(get_object_request);
if (!get_object_outcome.IsSuccess())
TF_SetStatusFromAWSError(get_object_outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE)
return -1;
int64_t read = get_object_outcome.GetResult().GetContentLength();
if (read < n)
    TF_SetStatus(status, TF_OUT_OF_RANGE, "Read fewer bytes than requested");
get_object_outcome.GetResult().GetBody().read(buffer, read);
return read;
}
static int64_t ReadS3TransferManager(S3File* s3_file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto create_download_stream = [&]() {
return Aws::New<TFS3UnderlyingStream>(
"S3ReadStream",
Aws::New<Aws::Utils::Stream::PreallocatedStreamBuf>(
"S3ReadStream", reinterpret_cast<unsigned char*>(buffer), n));
};
auto handle = s3_file->transfer_manager->DownloadFile(
s3_file->bucket, s3_file->object, offset, n, create_download_stream);
handle->WaitUntilFinished();
size_t retries = 0;
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
handle->GetLastError().GetResponseCode() !=
Aws::Http::HttpResponseCode::REQUESTED_RANGE_NOT_SATISFIABLE &&
retries++ < kDownloadRetries) {
// Only failed parts will be downloaded again.
s3_file->transfer_manager->RetryDownload(handle);
handle->WaitUntilFinished();
}
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
TF_SetStatusFromAWSError(handle->GetLastError(), status);
else
TF_SetStatus(status, TF_OK, "");
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_OUT_OF_RANGE)
return -1;
int64_t read = handle->GetBytesTransferred();
if (read < n)
    TF_SetStatus(status, TF_OUT_OF_RANGE, "Read fewer bytes than requested");
return read;
}
int64_t Read(const TF_RandomAccessFile* file, uint64_t offset, size_t n,
char* buffer, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
if (s3_file->use_multi_part_download)
return ReadS3TransferManager(s3_file, offset, n, buffer, status);
else
return ReadS3Client(s3_file, offset, n, buffer, status);
}
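// Read returns the number of bytes actually read (or -1 on a hard failure) and
// reports TF_OUT_OF_RANGE on a short read; e.g. asking for 100 bytes at offset
// 0 of a 40-byte object returns 40 with status TF_OUT_OF_RANGE, matching the
// contract of TensorFlow's RandomAccessFile.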
} // namespace tf_random_access_file
// SECTION 2. Implementation for `TF_WritableFile`
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct S3File {
Aws::String bucket;
Aws::String object;
std::shared_ptr<Aws::S3::S3Client> s3_client;
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager;
bool sync_needed;
std::shared_ptr<Aws::Utils::TempFile> outfile;
S3File(Aws::String bucket, Aws::String object,
std::shared_ptr<Aws::S3::S3Client> s3_client,
std::shared_ptr<Aws::Transfer::TransferManager> transfer_manager)
: bucket(bucket),
object(object),
s3_client(s3_client),
transfer_manager(transfer_manager),
outfile(Aws::MakeShared<Aws::Utils::TempFile>(
kS3FileSystemAllocationTag, nullptr, "_s3_filesystem_XXXXXX",
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
std::ios_base::out)) {}
} S3File;
void Cleanup(TF_WritableFile* file) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
delete s3_file;
}
void Append(const TF_WritableFile* file, const char* buffer, size_t n,
TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
if (!s3_file->outfile) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"The internal temporary file is not writable.");
return;
}
s3_file->sync_needed = true;
s3_file->outfile->write(buffer, n);
if (!s3_file->outfile->good())
TF_SetStatus(status, TF_INTERNAL,
"Could not append to the internal temporary file.");
else
TF_SetStatus(status, TF_OK, "");
}
int64_t Tell(const TF_WritableFile* file, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
if (position == -1)
TF_SetStatus(status, TF_INTERNAL,
"tellp on the internal temporary file failed");
else
TF_SetStatus(status, TF_OK, "");
return position;
}
void Sync(const TF_WritableFile* file, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
if (!s3_file->outfile) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"The internal temporary file is not writable.");
return;
}
if (!s3_file->sync_needed) {
TF_SetStatus(status, TF_OK, "");
return;
}
auto position = static_cast<int64_t>(s3_file->outfile->tellp());
auto handle = s3_file->transfer_manager->UploadFile(
s3_file->outfile, s3_file->bucket, s3_file->object,
"application/octet-stream", Aws::Map<Aws::String, Aws::String>());
handle->WaitUntilFinished();
size_t retries = 0;
while (handle->GetStatus() == Aws::Transfer::TransferStatus::FAILED &&
retries++ < kUploadRetries) {
// if multipart upload was used, only the failed parts will be re-sent
s3_file->transfer_manager->RetryUpload(s3_file->outfile, handle);
handle->WaitUntilFinished();
}
if (handle->GetStatus() != Aws::Transfer::TransferStatus::COMPLETED)
return TF_SetStatusFromAWSError(handle->GetLastError(), status);
s3_file->outfile->clear();
s3_file->outfile->seekp(position);
s3_file->sync_needed = false;
TF_SetStatus(status, TF_OK, "");
}
void Flush(const TF_WritableFile* file, TF_Status* status) {
Sync(file, status);
}
void Close(const TF_WritableFile* file, TF_Status* status) {
auto s3_file = static_cast<S3File*>(file->plugin_file);
if (s3_file->outfile) {
Sync(file, status);
if (TF_GetCode(status) != TF_OK) return;
s3_file->outfile.reset();
}
TF_SetStatus(status, TF_OK, "");
}
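// Typical lifecycle of a writable file here: Append() buffers bytes in the
// local temporary file, and Sync()/Flush()/Close() upload the whole buffered
// file to `s3://bucket/object` through the TransferManager, retrying failed
// parts up to kUploadRetries times.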
} // namespace tf_writable_file
// SECTION 3. Implementation for `TF_ReadOnlyMemoryRegion`
// ----------------------------------------------------------------------------
namespace tf_read_only_memory_region {
typedef struct S3MemoryRegion {
std::unique_ptr<char[]> data;
uint64_t length;
} S3MemoryRegion;
void Cleanup(TF_ReadOnlyMemoryRegion* region) {
auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
delete r;
}
const void* Data(const TF_ReadOnlyMemoryRegion* region) {
auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
return reinterpret_cast<const void*>(r->data.get());
}
uint64_t Length(const TF_ReadOnlyMemoryRegion* region) {
auto r = static_cast<S3MemoryRegion*>(region->plugin_memory_region);
return r->length;
}
} // namespace tf_read_only_memory_region
// SECTION 4. Implementation for `TF_Filesystem`, the actual filesystem
// ----------------------------------------------------------------------------
namespace tf_s3_filesystem {
S3File::S3File()
: s3_client(nullptr, ShutdownClient),
executor(nullptr),
transfer_managers(),
multi_part_chunk_sizes(),
use_multi_part_download(true),
initialization_lock() {
uint64_t temp_value;
multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD] =
absl::SimpleAtoi(getenv("S3_MULTI_PART_UPLOAD_CHUNK_SIZE"), &temp_value)
? temp_value
: kS3MultiPartUploadChunkSize;
multi_part_chunk_sizes[Aws::Transfer::TransferDirection::DOWNLOAD] =
absl::SimpleAtoi(getenv("S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE"), &temp_value)
? temp_value
: kS3MultiPartDownloadChunkSize;
use_multi_part_download =
absl::SimpleAtoi(getenv("S3_DISABLE_MULTI_PART_DOWNLOAD"), &temp_value)
? (temp_value != 1)
: use_multi_part_download;
transfer_managers.emplace(Aws::Transfer::TransferDirection::UPLOAD, nullptr);
transfer_managers.emplace(Aws::Transfer::TransferDirection::DOWNLOAD,
nullptr);
}
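// For illustration: S3_MULTI_PART_UPLOAD_CHUNK_SIZE=104857600 and
// S3_MULTI_PART_DOWNLOAD_CHUNK_SIZE=104857600 raise both chunk sizes to
// 100 MB, while S3_DISABLE_MULTI_PART_DOWNLOAD=1 makes reads fall back to
// plain GetObject requests instead of the TransferManager.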
void Init(TF_Filesystem* filesystem, TF_Status* status) {
filesystem->plugin_filesystem = new S3File();
TF_SetStatus(status, TF_OK, "");
}
void Cleanup(TF_Filesystem* filesystem) {
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
delete s3_file;
}
void NewRandomAccessFile(const TF_Filesystem* filesystem, const char* path,
TF_RandomAccessFile* file, TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::DOWNLOAD, s3_file);
file->plugin_file = new tf_random_access_file::S3File(
{bucket, object, s3_file->s3_client,
s3_file->transfer_managers[Aws::Transfer::TransferDirection::DOWNLOAD],
s3_file->use_multi_part_download});
TF_SetStatus(status, TF_OK, "");
}
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
file->plugin_file = new tf_writable_file::S3File(
bucket, object, s3_file->s3_client,
s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
TF_SetStatus(status, TF_OK, "");
}
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
  // We need to delete `file->plugin_file` in case of errors, so `file` is
  // wrapped in a `unique_ptr` whose deleter calls Cleanup. Setting
  // `file->plugin_file` to `nullptr` first avoids a segmentation fault if that
  // deleter runs before the plugin file has been created.
file->plugin_file = nullptr;
std::unique_ptr<TF_WritableFile, void (*)(TF_WritableFile*)> writer(
file, [](TF_WritableFile* file) {
if (file != nullptr && file->plugin_file != nullptr) {
tf_writable_file::Cleanup(file);
}
});
writer->plugin_file = new tf_writable_file::S3File(
bucket, object, s3_file->s3_client,
s3_file->transfer_managers[Aws::Transfer::TransferDirection::UPLOAD]);
TF_SetStatus(status, TF_OK, "");
  // Wrapping in a `std::unique_ptr` to prevent memory leaks.
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
  // We set `reader->plugin_file` to `nullptr` to avoid a segmentation fault
  // when the `unique_ptr` deleter runs.
reader->plugin_file = nullptr;
NewRandomAccessFile(filesystem, path, reader.get(), status);
if (TF_GetCode(status) != TF_OK) return;
uint64_t offset = 0;
std::string buffer(kS3ReadAppendableFileBufferSize, {});
while (true) {
auto read = tf_random_access_file::Read(reader.get(), offset,
kS3ReadAppendableFileBufferSize,
&buffer[0], status);
if (TF_GetCode(status) == TF_NOT_FOUND) {
break;
} else if (TF_GetCode(status) == TF_OK) {
offset += read;
tf_writable_file::Append(file, buffer.c_str(), read, status);
if (TF_GetCode(status) != TF_OK) return;
} else if (TF_GetCode(status) == TF_OUT_OF_RANGE) {
offset += read;
tf_writable_file::Append(file, buffer.c_str(), read, status);
if (TF_GetCode(status) != TF_OK) return;
break;
} else {
return;
}
}
writer.release();
TF_SetStatus(status, TF_OK, "");
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
if (object.empty()) {
Aws::S3::Model::HeadBucketRequest head_bucket_request;
head_bucket_request.WithBucket(bucket);
auto head_bucket_outcome =
s3_file->s3_client->HeadBucket(head_bucket_request);
if (!head_bucket_outcome.IsSuccess())
return TF_SetStatusFromAWSError(head_bucket_outcome.GetError(), status);
stats->length = 0;
stats->is_directory = 1;
stats->mtime_nsec = 0;
return TF_SetStatus(status, TF_OK, "");
}
bool found = false;
Aws::S3::Model::HeadObjectRequest head_object_request;
head_object_request.WithBucket(bucket).WithKey(object);
head_object_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
auto head_object_outcome =
s3_file->s3_client->HeadObject(head_object_request);
if (head_object_outcome.IsSuccess()) {
stats->length = head_object_outcome.GetResult().GetContentLength();
stats->is_directory = 0;
stats->mtime_nsec =
head_object_outcome.GetResult().GetLastModified().Millis() * 1e6;
found = true;
  } else {
    TF_SetStatusFromAWSError(head_object_outcome.GetError(), status);
    if (TF_GetCode(status) == TF_FAILED_PRECONDITION) return;
  }
auto prefix = object;
if (prefix.back() != '/') {
prefix.push_back('/');
}
Aws::S3::Model::ListObjectsRequest list_objects_request;
list_objects_request.WithBucket(bucket).WithPrefix(prefix).WithMaxKeys(1);
list_objects_request.SetResponseStreamFactory(
[]() { return Aws::New<Aws::StringStream>(kS3FileSystemAllocationTag); });
auto list_objects_outcome =
s3_file->s3_client->ListObjects(list_objects_request);
if (list_objects_outcome.IsSuccess()) {
auto objects = list_objects_outcome.GetResult().GetContents();
if (objects.size() > 0) {
stats->length = 0;
stats->is_directory = 1;
stats->mtime_nsec = objects[0].GetLastModified().Millis() * 1e6;
found = true;
}
} else {
TF_SetStatusFromAWSError(list_objects_outcome.GetError(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION) return;
}
if (!found)
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("Object ", path, " does not exist").c_str());
TF_SetStatus(status, TF_OK, "");
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_FileStatistics stats;
Stat(filesystem, path, &stats, status);
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
TF_FileStatistics stats;
Stat(filesystem, path, &stats, status);
return stats.length;
}
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
Aws::String bucket, object;
ParseS3Path(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
auto size = GetFileSize(filesystem, path, status);
if (TF_GetCode(status) != TF_OK) return;
if (size == 0)
return TF_SetStatus(status, TF_INVALID_ARGUMENT, "File is empty");
std::unique_ptr<char[]> data(new char[size]);
  // Wrapping in a `std::unique_ptr` to prevent memory leaks.
std::unique_ptr<TF_RandomAccessFile, void (*)(TF_RandomAccessFile*)> reader(
new TF_RandomAccessFile, [](TF_RandomAccessFile* file) {
if (file != nullptr) {
if (file->plugin_file != nullptr)
tf_random_access_file::Cleanup(file);
delete file;
}
});
  // We set `reader->plugin_file` to `nullptr` to avoid a segmentation fault
  // when the `unique_ptr` deleter runs.
reader->plugin_file = nullptr;
NewRandomAccessFile(filesystem, path, reader.get(), status);
if (TF_GetCode(status) != TF_OK) return;
auto read =
tf_random_access_file::Read(reader.get(), 0, size, data.get(), status);
if (TF_GetCode(status) != TF_OK) return;
region->plugin_memory_region = new tf_read_only_memory_region::S3MemoryRegion(
{std::move(data), static_cast<uint64_t>(read)});
TF_SetStatus(status, TF_OK, "");
}
static void SimpleCopyFile(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, S3File* s3_file,
TF_Status* status) {
Aws::S3::Model::CopyObjectRequest copy_object_request;
copy_object_request.WithCopySource(source)
.WithBucket(bucket_dst)
.WithKey(object_dst);
auto copy_object_outcome =
s3_file->s3_client->CopyObject(copy_object_request);
if (!copy_object_outcome.IsSuccess())
TF_SetStatusFromAWSError(copy_object_outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
};
using EtagOutcome =
Aws::Utils::Outcome<Aws::String, Aws::Client::AWSError<Aws::S3::S3Errors>>;
typedef struct MultipartCopyAsyncContext
: public Aws::Client::AsyncCallerContext {
int part_number;
int* num_finished_parts;
Aws::Vector<EtagOutcome>* etag_outcomes;
// lock and cv for multi part copy
absl::Mutex* multi_part_copy_mutex;
absl::CondVar* multi_part_copy_cv;
} MultipartCopyAsyncContext;
static void AbortMultiPartCopy(const Aws::String& bucket_dst,
const Aws::String& object_dst,
const Aws::String& upload_id, S3File* s3_file,
TF_Status* status) {
Aws::S3::Model::AbortMultipartUploadRequest request;
request.WithBucket(bucket_dst).WithKey(object_dst).WithUploadId(upload_id);
auto outcome = s3_file->s3_client->AbortMultipartUpload(request);
if (!outcome.IsSuccess())
TF_SetStatusFromAWSError(outcome.GetError(), status);
else
TF_SetStatus(status, TF_OK, "");
}
static void MultiPartCopyCallback(
const Aws::S3::Model::UploadPartCopyRequest& request,
const Aws::S3::Model::UploadPartCopyOutcome& outcome,
const std::shared_ptr<const MultipartCopyAsyncContext>& context) {
// Access to `etag_outcomes` should be thread-safe because of distinct
// `part_number`.
auto part_number = context->part_number;
auto etag_outcomes = context->etag_outcomes;
if (outcome.IsSuccess()) {
(*etag_outcomes)[part_number] =
outcome.GetResult().GetCopyPartResult().GetETag();
} else {
(*etag_outcomes)[part_number] = outcome.GetError();
}
{
absl::MutexLock l(context->multi_part_copy_mutex);
(*context->num_finished_parts)++;
context->multi_part_copy_cv->Signal();
}
}
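// The callback above runs on an AWS SDK thread; the mutex/condition-variable
// pair is what lets MultiPartCopy below block until every asynchronously
// issued UploadPartCopy request has reported back.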
static void MultiPartCopy(const Aws::String& source,
const Aws::String& bucket_dst,
const Aws::String& object_dst, const size_t num_parts,
const uint64_t file_size, S3File* s3_file,
TF_Status* status) {
Aws::S3::Model::CreateMultipartUploadRequest create_multipart_upload_request;
create_multipart_upload_request.WithBucket(bucket_dst).WithKey(object_dst);
GetS3Client(s3_file);
GetTransferManager(Aws::Transfer::TransferDirection::UPLOAD, s3_file);
auto create_multipart_upload_outcome =
s3_file->s3_client->CreateMultipartUpload(
create_multipart_upload_request);
if (!create_multipart_upload_outcome.IsSuccess())
return TF_SetStatusFromAWSError(create_multipart_upload_outcome.GetError(),
status);
auto upload_id = create_multipart_upload_outcome.GetResult().GetUploadId();
int num_finished_parts = 0;
// Keep track of `Outcome` of each upload part.
Aws::Vector<EtagOutcome> etag_outcomes(num_parts);
  // Mutex protecting `num_finished_parts`, used together with the condition
  // variable below.
absl::Mutex multi_part_copy_mutex;
// Condition variable to be used with above mutex for synchronization.
absl::CondVar multi_part_copy_cv;
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
size_t retries = 0;
while (retries++ < 3) {
// Queue up parts.
for (auto part_number = 0; part_number < num_parts; ++part_number) {
if (etag_outcomes[part_number].IsSuccess()) continue;
uint64_t start_pos = part_number * chunk_size;
uint64_t end_pos = start_pos + chunk_size - 1;
if (end_pos >= file_size) end_pos = file_size - 1;
Aws::String range =
absl::StrCat("bytes=", start_pos, "-", end_pos).c_str();
Aws::S3::Model::UploadPartCopyRequest upload_part_copy_request;
upload_part_copy_request.WithBucket(bucket_dst)
.WithKey(object_dst)
.WithCopySource(source)
.WithCopySourceRange(range)
// S3 API partNumber starts from 1.
.WithPartNumber(part_number + 1)
.WithUploadId(upload_id);
auto multi_part_context =
Aws::MakeShared<MultipartCopyAsyncContext>("MultiPartCopyContext");
multi_part_context->part_number = part_number;
multi_part_context->num_finished_parts = &num_finished_parts;
multi_part_context->etag_outcomes = &etag_outcomes;
multi_part_context->multi_part_copy_mutex = &multi_part_copy_mutex;
multi_part_context->multi_part_copy_cv = &multi_part_copy_cv;
auto callback =
[](const Aws::S3::S3Client* client,
const Aws::S3::Model::UploadPartCopyRequest& request,
const Aws::S3::Model::UploadPartCopyOutcome& outcome,
const std::shared_ptr<const Aws::Client::AsyncCallerContext>&
context) {
auto multipart_context =
std::static_pointer_cast<const MultipartCopyAsyncContext>(
context);
MultiPartCopyCallback(request, outcome, multipart_context);
};
std::shared_ptr<const Aws::Client::AsyncCallerContext> context =
multi_part_context;
s3_file->s3_client->UploadPartCopyAsync(upload_part_copy_request,
callback, context);
}
// Wait till they finish.
{
absl::MutexLock l(&multi_part_copy_mutex);
      // Wait on the condition variable until signaled, then re-check the
      // count, since there can be spurious wakeups.
while (num_finished_parts != num_parts) {
multi_part_copy_cv.Wait(&multi_part_copy_mutex);
}
}
// check if there was any error for any part.
for (auto part_number = 0; part_number < num_parts; ++part_number) {
if (!etag_outcomes[part_number].IsSuccess()) {
if (retries >= 3) {
AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file,
status);
if (TF_GetCode(status) != TF_OK) return;
return TF_SetStatusFromAWSError(etag_outcomes[part_number].GetError(),
status);
} else {
// Retry.
num_finished_parts--;
}
}
}
}
Aws::S3::Model::CompletedMultipartUpload completed_multipart_upload;
  // If any part still had an error, the loop above would have aborted and
  // returned. Here we add the ETag of each completed part to the final
  // `completed_multipart_upload`; note these parts have to be added in order.
for (int part_number = 0; part_number < num_parts; ++part_number) {
Aws::S3::Model::CompletedPart completed_part;
completed_part.SetPartNumber(part_number + 1);
completed_part.SetETag(etag_outcomes[part_number].GetResult());
completed_multipart_upload.AddParts(completed_part);
}
Aws::S3::Model::CompleteMultipartUploadRequest
complete_multipart_upload_request;
complete_multipart_upload_request.WithBucket(bucket_dst)
.WithKey(object_dst)
.WithUploadId(upload_id)
.WithMultipartUpload(completed_multipart_upload);
auto complete_multipart_upload_outcome =
s3_file->s3_client->CompleteMultipartUpload(
complete_multipart_upload_request);
if (!complete_multipart_upload_outcome.IsSuccess())
AbortMultiPartCopy(bucket_dst, object_dst, upload_id, s3_file, status);
else
return TF_SetStatus(status, TF_OK, "");
if (TF_GetCode(status) == TF_OK)
return TF_SetStatusFromAWSError(
complete_multipart_upload_outcome.GetError(), status);
};
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_Status* status) {
auto file_size = GetFileSize(filesystem, src, status);
if (TF_GetCode(status) != TF_OK) return;
if (file_size == 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Source is a directory or empty file");
Aws::String bucket_src, object_src;
ParseS3Path(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
Aws::String copy_src = bucket_src + "/" + object_src;
Aws::String bucket_dst, object_dst;
ParseS3Path(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto s3_file = static_cast<S3File*>(filesystem->plugin_filesystem);
auto chunk_size =
s3_file->multi_part_chunk_sizes[Aws::Transfer::TransferDirection::UPLOAD];
size_t num_parts = 1;
if (file_size > chunk_size) num_parts = ceil((float)file_size / chunk_size);
if (num_parts == 1)
SimpleCopyFile(copy_src, bucket_dst, object_dst, s3_file, status);
else if (num_parts > 10000)
TF_SetStatus(
status, TF_UNIMPLEMENTED,
        absl::StrCat("MultiPartCopy with more than 10000 parts is not "
                     "supported. Your object ",
                     src, " requires ", num_parts,
                     " parts with the current chunk size of ", chunk_size,
                     ". You can increase the chunk size via the environment "
                     "variable S3_MULTI_PART_UPLOAD_CHUNK_SIZE.")
            .c_str());
else
MultiPartCopy(copy_src, bucket_dst, object_dst, num_parts, file_size,
s3_file, status);
}
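// Worked example (for illustration): copying a 120 MB object with the default
// 50 MB upload chunk size gives num_parts == 3 (ceil of 120 MB / 50 MB), so
// MultiPartCopy issues three UploadPartCopy requests covering
// "bytes=0-52428799", "bytes=52428800-104857599" and
// "bytes=104857600-125829119". The 10000-part ceiling above mirrors S3's
// multipart upload limit.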
// TODO(vnvo2409): Implement later
} // namespace tf_s3_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
const char* uri) {
TF_SetFilesystemVersionMetadata(ops);
ops->scheme = strdup(uri);
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {
info->plugin_memory_allocate = plugin_memory_allocate;
info->plugin_memory_free = plugin_memory_free;
info->num_schemes = 1;
info->ops = static_cast<TF_FilesystemPluginOps*>(
plugin_memory_allocate(info->num_schemes * sizeof(info->ops[0])));
ProvideFilesystemSupportFor(&info->ops[0], "s3");
}
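// When TensorFlow loads this shared object as a modular filesystem plugin, it
// looks up and calls TF_InitPlugin above, which allocates a single
// TF_FilesystemPluginOps entry and advertises the "s3" URI scheme for it.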
View File
@ -0,0 +1,47 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
#include <aws/core/Aws.h>
#include <aws/core/utils/StringUtils.h>
#include <aws/core/utils/memory/stl/AWSMap.h>
#include <aws/core/utils/threading/Executor.h>
#include <aws/s3/S3Client.h>
#include <aws/transfer/TransferManager.h>
#include "absl/synchronization/mutex.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
namespace tf_s3_filesystem {
typedef struct S3File {
std::shared_ptr<Aws::S3::S3Client> s3_client;
std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor;
  // We need two `TransferManager`s: one for multipart upload and one for
  // multipart download.
Aws::Map<Aws::Transfer::TransferDirection,
std::shared_ptr<Aws::Transfer::TransferManager>>
transfer_managers;
// Sizes to split objects during multipart upload/download.
Aws::Map<Aws::Transfer::TransferDirection, uint64_t> multi_part_chunk_sizes;
bool use_multi_part_download;
absl::Mutex initialization_lock;
S3File();
} S3File;
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
} // namespace tf_s3_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_S3_S3_FILESYSTEM_H_
View File
@ -0,0 +1,23 @@
# Library of gradient functions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "math_grad",
srcs = ["math_grad.cc"],
hdrs = [
"math_grad.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)
View File
@ -0,0 +1,53 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
using tensorflow::ops::Identity;
namespace tensorflow {
namespace gradients {
namespace {
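// Gradient for Add: for z = x + y we have dz/dx = dz/dy = 1, so the gradient
// of each input is simply the incoming gradient, materialized below with two
// Identity ops.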
class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity0"));
(*grad_outputs)[0] = identity_outputs[0];
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
absl::MakeSpan(identity_outputs),
"Identity1"));
(*grad_outputs)[1] = identity_outputs[0];
return Status::OK();
}
~AddGradientFunction() override {}
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
}
} // namespace gradients
} // namespace tensorflow
View File
@ -0,0 +1,26 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#define TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
#include "tensorflow/c/eager/gradients.h"
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_MATH_GRAD_H_
View File
@ -0,0 +1,24 @@
# Experimental ops. These will eventually be replaced by machine-generated versions.
package(
licenses = ["notice"], # Apache 2.0
)
cc_library(
name = "array_ops",
srcs = [
"array_ops.cc",
],
hdrs = [
"array_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
View File
@ -0,0 +1,38 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
// Creates an Identity op.
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr identity_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(identity_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
int num_retvals = 1;
return identity_op->Execute(outputs, &num_retvals);
}
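// For illustration: ops::Identity(ctx, {x}, absl::MakeSpan(out), "Identity0")
// builds and executes a single Identity node whose one output carries the same
// value as `x` (here `x`, `out` and `ctx` are assumed to be caller-provided).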
} // namespace ops
} // namespace tensorflow
View File
@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace ops {
Status Identity(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
View File
@ -113,8 +113,23 @@ cc_library(
deps = [
":concrete_function",
":saved_model_api",
":saved_model_utils",
"//tensorflow/c:tensor_interface",
"//tensorflow/c/eager:immediate_execution_context",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/experimental/saved_model/core/ops:restore_ops",
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
"//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible",
"//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function",
"//tensorflow/c/experimental/saved_model/core/revived_types:variable",
"//tensorflow/cc/saved_model:bundle_v2",
"//tensorflow/cc/saved_model:constants",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
)
View File
@ -15,47 +15,360 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/tf_saved_model_api.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/algorithm/container.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/ops/restore_ops.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
namespace tensorflow {
// Maps from a FunctionDef's name to FunctionDef, for a given FunctionDefLibrary
using FunctionDefMap =
std::unordered_map<StringPiece, const tensorflow::FunctionDef*,
StringPieceHasher>;
// Maps from a NodeDef's name to its corresponding AttrValues, for a given
// GraphDef.
using NodeAttrMap =
std::unordered_map<StringPiece, const AttrValueMap*, StringPieceHasher>;
// Maps from a node ID to a "Revived Object" implementing
// "TensorHandleConvertible".
using RevivedObjectMap =
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>;
// Maps from a functiondef's name to the corresponding "TFConcreteFunction"
using ConcreteFunctionMap =
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>;
namespace {
Status ConstantFromSavedConstant(
ImmediateExecutionContext* ctx,
const tensorflow::SavedConstant& saved_constant,
const NodeAttrMap& node_attr_map, std::unique_ptr<Constant>* output) {
const std::string& const_op_name = saved_constant.operation();
const auto& node_name_and_attrs = node_attr_map.find(const_op_name);
if (node_name_and_attrs == node_attr_map.end()) {
    return errors::FailedPrecondition(
        "Unable to find Const operation with name '", const_op_name,
        "' in SavedModel graphdef");
}
const AttrValueMap* attrs = node_name_and_attrs->second;
const auto& attr_name_and_value = attrs->find("value");
if (attr_name_and_value == attrs->end()) {
return errors::FailedPrecondition("Unable to find Const operation '",
const_op_name, "'s value attribute");
}
const TensorProto& tensor_proto = attr_name_and_value->second.tensor();
return internal::TensorProtoToConstant(ctx, tensor_proto, output);
}
// Restores all non-function objects in the SavedModel's object graph.
// This function walks through the metagraph's saved object graph, and
// constructs revived versions of SavedVariable, SavedConstant, SavedAsset, and
// SavedResources. These are returned via the `revived_objects` parameter.
Status ReviveObjects(
const MetaGraphDef& metagraph, ImmediateExecutionContext* context,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>*
revived_objects) {
  // This is needed to restore "Constant" nodes by looking up their
  // "value" attribute.
NodeAttrMap node_attr_map = internal::NodeToAttrMap(metagraph.graph_def());
// Iterate through all the saved objects, restoring objects as we go.
// We don't recreate functions until all other objects have been created.
for (int i = 0; i < metagraph.object_graph_def().nodes_size(); ++i) {
const SavedObject& node = metagraph.object_graph_def().nodes(i);
if (node.kind_case() == SavedObject::kVariable) {
std::unique_ptr<Variable> variable;
TF_RETURN_IF_ERROR(
internal::LoadSavedVariable(context, node.variable(), &variable));
(*revived_objects)[i] = std::move(variable);
} else if (node.kind_case() == SavedObject::kConstant) {
std::unique_ptr<Constant> constant;
TF_RETURN_IF_ERROR(ConstantFromSavedConstant(context, node.constant(),
node_attr_map, &constant));
(*revived_objects)[i] = std::move(constant);
} else if (node.kind_case() == SavedObject::kAsset) {
// TODO(bmzhao): Implement Asset C++ class. This should be just recreating
// the full path to the asset file:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/saved_model/load.py#L395-L396
// and storing it as a string tensor:
// https://github.com/tensorflow/tensorflow/blob/6a0bdbdb7c48a3491ae1277083ae3dafb4ab4d7a/tensorflow/python/training/tracking/tracking.py#L324-L325
return errors::Unimplemented("SavedAsset loading is not implemented yet");
} else if (node.kind_case() == SavedObject::kResource) {
// TODO(bmzhao): Figure out how resource loading works and implement it
return errors::Unimplemented(
"SavedResource loading is not implemented yet");
}
}
return Status();
}
Status ReviveFunctions(const MetaGraphDef& metagraph,
const RevivedObjectMap& revived_objects,
ImmediateExecutionContext* context,
ConcreteFunctionMap* restored_functions) {
const FunctionDefMap function_def_map =
internal::FunctionNameToFunctionDefMap(metagraph.graph_def().library());
// Iterate through all objects, only examining functions.
for (const SavedObject& node : metagraph.object_graph_def().nodes()) {
if (node.kind_case() == SavedObject::kBareConcreteFunction) {
const std::string& function_name =
node.bare_concrete_function().concrete_function_name();
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
} else if (node.kind_case() == SavedObject::kFunction) {
      // We only allow loading functions that have an annotated input signature,
      // which means there is a 1:1 correspondence between tf.function
// <=> SavedFunction <=> SavedConcreteFunction <=> FunctionDef. This is
// the same restriction that MLIR has:
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2677-L2707
const SavedFunction& saved_function = node.function();
if (saved_function.concrete_functions_size() != 1) {
return errors::FailedPrecondition(
"Only tf.functions annotated with an input signature are supported "
"by SavedModelAPI. This means that there should only be a single "
"ConcreteFunction per tf.function");
}
const std::string& function_name = saved_function.concrete_functions(0);
const SavedConcreteFunction& saved_concrete_function =
metagraph.object_graph_def().concrete_functions().at(function_name);
const FunctionDef* function_def = function_def_map.at(function_name);
std::unique_ptr<TFConcreteFunction> concrete_function;
TF_RETURN_IF_ERROR(internal::LoadTFConcreteFunction(
saved_concrete_function, function_def, revived_objects, context,
&concrete_function));
(*restored_functions)[function_name] = std::move(concrete_function);
}
}
return Status();
}
const TrackableObjectGraph::TrackableObject::SerializedTensor*
FindSerializedTensorInTrackable(
const TrackableObjectGraph::TrackableObject& trackable_object,
absl::string_view name) {
for (const auto& maybe_serialized_tensor : trackable_object.attributes()) {
if (maybe_serialized_tensor.name() == name) {
return &maybe_serialized_tensor;
}
}
return nullptr;
}
// This function reads the Checkpoint embedded in the SavedModel, and calls the
// appropriate Restore ops on each of the variables.
// Note(bmzhao): Conceptually, objects that contain checkpointable state
// implement the "_gather_saveables_for_checkpoint" method
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/tracking/base.py#L953-L983
// which returns a dict of string key -> EITHER:
// 1. python callable (taking a checkpoint key) returning SaveableObject OR
// 2. variable (partitioned/resource/reference or otherwise)
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L58.
// The string key becomes the "name" attribute of the SerializedTensor proto
// in the TrackableObjectGraph,
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/core/protobuf/trackable_object_graph.proto#L26
// And the checkpoint_key is a globally unique string derived from this name:
// https://github.com/tensorflow/tensorflow/blob/842df9e6b516e42578a8d23b35d41176b9a6cf1d/tensorflow/python/training/tracking/graph_view.py#L236-L241
// SaveableObjects model the information needed to pass to the SaveV2/RestoreV2
// ops via their SaveSpec members
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L21,
// which contain the "real" checkpoint keys into the TensorBundle SSTable.
// They also contain the logic needed to take the restored tensors from
// RestoreV2 and load them back into the "object" they came from via their
// overridden "restore" method:
// https://github.com/tensorflow/tensorflow/blob/ddc1bbad3dfd4a089eb96014f26cc16664b1b2f8/tensorflow/python/training/saving/saveable_object.py#L85
Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
const RevivedObjectMap& revived_objects,
const std::string& directory,
ImmediateExecutionContext* context) {
// TODO(bmzhao): Batch up all the restores into a single restore op per
// device, following logic in MultiDeviceSaver.
TF_RETURN_IF_ERROR(bundle->VisitObjectsToRestore(
[&revived_objects, &directory, context, bundle](
int node, const TrackableObjectGraph::TrackableObject& trackable) {
if (bundle->saved_object_graph().nodes(node).kind_case() !=
SavedObject::kVariable) {
// TODO(bmzhao): This requires using the newly added Save/Restore
// functions from
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
return errors::Unimplemented(
"Restoring non-variable objects has not been implemented yet. ");
}
Variable* variable =
down_cast<Variable*>(revived_objects.at(node).get());
// Restore the tensor's value from the checkpoint
const TrackableObjectGraph::TrackableObject::SerializedTensor*
attribute =
FindSerializedTensorInTrackable(trackable, "VARIABLE_VALUE");
if (attribute == nullptr) {
return errors::FailedPrecondition(
"Could not find SerializedTensor with name VARIABLE_VALUE for "
"saved variable");
}
const std::string& checkpoint_key = attribute->checkpoint_key();
std::string variables_path_prefix =
io::JoinPath(directory, kSavedModelVariablesDirectory,
kSavedModelVariablesFilename);
ImmediateTensorHandlePtr restored_output;
TF_RETURN_IF_ERROR(internal::SingleRestore(
context, variables_path_prefix, checkpoint_key, variable->dtype(),
&restored_output));
// Assign the restored tensor's value to the variable
return variable->Assign(restored_output.get());
}));
return Status();
}
} // namespace
Status TFSavedModelAPI::GetFunction(const std::string& function_path,
ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a function.
return errors::Unimplemented(
"Retrieving functions is unimplemented currently");
const SavedObject* object =
internal::FindNodeAtPath(function_path, bundle_.saved_object_graph());
if (object == nullptr) {
return errors::NotFound("No saved object found at path ", function_path);
}
if (object->kind_case() == SavedObject::kBareConcreteFunction) {
*function =
concrete_functions_
.at(object->bare_concrete_function().concrete_function_name())
.get();
} else if (object->kind_case() == SavedObject::kFunction) {
*function =
concrete_functions_.at(object->function().concrete_functions(0)).get();
} else {
return errors::InvalidArgument(function_path,
" is not a path to a Function.");
}
return Status();
}
Status TFSavedModelAPI::GetSignatureDefFunction(
const std::string& signature_def_key, ConcreteFunction** function) {
// TODO(bmzhao): Add support for retrieving a signaturedef function.
return errors::Unimplemented(
"Retrieving functions is unimplemented currently");
"Retrieving SignatureDef functions is unimplemented currently");
}
std::vector<ConcreteFunction*> TFSavedModelAPI::ListFunctions() {
std::vector<ConcreteFunction*> result;
result.reserve(functions_.size());
for (ConcreteFunction& function : functions_) {
result.push_back(&function);
result.reserve(concrete_functions_.size());
for (auto& index_and_function : concrete_functions_) {
result.push_back(index_and_function.second.get());
}
return result;
}
TFSavedModelAPI::TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions)
: directory_(directory),
bundle_(std::move(bundle)),
revived_objects_(std::move(revived_objects)),
concrete_functions_(std::move(concrete_functions)) {}
Status TFSavedModelAPI::Load(
const std::string& directory,
const absl::optional<std::unordered_set<std::string>>& tags,
ImmediateExecutionContext* context, std::unique_ptr<TFSavedModelAPI>* out) {
// TODO(bmzhao): Add support for loading a TFSavedModelImpl.
return errors::Unimplemented(
"TFSavedModelAPIImpl loading is unimplemented currently");
// TODO(bmzhao): Add support for loading a TF1 SavedModel.
if (tags) {
return errors::Unimplemented(
"Loading saved models with explicit tags will be supported in the "
"future");
}
SavedModelV2Bundle bundle;
TF_RETURN_IF_ERROR(SavedModelV2Bundle::Load(directory, &bundle));
  // TODO(bmzhao): Mangle loaded function names so that different
  // models loaded in the same runtime Context don't clobber each other.
// This occurs in python here:
// https://github.com/tensorflow/tensorflow/blob/285b5fa15405c5e2c084080f52a1818be8648079/tensorflow/python/saved_model/function_deserialization.py#L438-L454
RevivedObjectMap revived_objects;
TF_RETURN_IF_ERROR(
ReviveObjects(bundle.meta_graph_def(), context, &revived_objects));
// TODO(bmzhao): When we later add support for loading resources, we need to
// handle the case where materializing a function's captures requires invoking
// other functions. This occurs when retrieving the resource handle for a
// TrackableResource:
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/saved_model/load.py#L240
// https://github.com/tensorflow/tensorflow/blob/f19c6efb4a8ba60e2492eedc98ef5375abb39dc7/tensorflow/python/training/tracking/tracking.py#L233
// This requires restoring functions in a topological sort order by capture
// dependencies.
ConcreteFunctionMap function_map;
TF_RETURN_IF_ERROR(ReviveFunctions(bundle.meta_graph_def(), revived_objects,
context, &function_map));
TF_RETURN_IF_ERROR(
RestoreCheckpoint(&bundle, revived_objects, directory, context));
out->reset(new TFSavedModelAPI(directory, std::move(bundle),
std::move(revived_objects),
std::move(function_map)));
return Status();
}
} // namespace tensorflow
View File
@ -16,14 +16,19 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_SAVED_MODEL_IMPL_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "absl/types/optional.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/experimental/saved_model/core/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h"
#include "tensorflow/c/experimental/saved_model/core/saved_model_api.h"
#include "tensorflow/cc/saved_model/bundle_v2.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
@ -63,8 +68,19 @@ class TFSavedModelAPI : public SavedModelAPI {
~TFSavedModelAPI() override = default;
private:
TFSavedModelAPI() = default;
std::vector<ConcreteFunction> functions_;
TFSavedModelAPI(
const std::string& directory, SavedModelV2Bundle bundle,
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects,
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions);
std::string directory_;
SavedModelV2Bundle bundle_;
std::unordered_map<int, std::unique_ptr<TensorHandleConvertible>>
revived_objects_;
std::unordered_map<std::string, std::unique_ptr<TFConcreteFunction>>
concrete_functions_;
};
} // namespace tensorflow
View File
@ -16,10 +16,15 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include <string>
#include <vector>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/test.h"
@ -92,12 +97,51 @@ TEST_P(CSavedModelAPITest, LoadsSavedModel) {
TF_SavedModel* saved_model =
TF_LoadSavedModel(model_dir.c_str(), ctx, status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(TF_GetCode(status), TF_UNIMPLEMENTED);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_ConcreteFunction* compute_fn =
TF_GetSavedModelConcreteFunction(saved_model, "compute", status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* compute_fn_op = TF_ConcreteFunctionGetCallOp(compute_fn, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
const TF_TensorHandleList* captures =
TF_ConcreteFunctionGetCaptures(compute_fn);
// TODO(bmzhao): Finish API on FunctionMetadata args, so we know how many
// inputs + outputs a function has.
std::vector<TFE_TensorHandle*> compute_fn_inputs;
TFE_TensorHandle* input_a = TestScalarTensorHandle(ctx, 2.0f);
TFE_TensorHandle* input_b = TestScalarTensorHandle(ctx, 1.0f);
compute_fn_inputs.reserve(2 + TF_TensorHandleListSize(captures));
compute_fn_inputs.push_back(input_a);
compute_fn_inputs.push_back(input_b);
for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) {
compute_fn_inputs.push_back(TF_TensorHandleListGet(captures, i));
}
TFE_OpAddInputList(compute_fn_op, compute_fn_inputs.data(),
compute_fn_inputs.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* compute_fn_outputs[1] = {nullptr};
int num_retvals = 1;
TFE_Execute(compute_fn_op, &compute_fn_outputs[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* result = TFE_TensorHandleResolve(compute_fn_outputs[0], status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
EXPECT_EQ(TF_NumDims(result), 0);
float output_value = *static_cast<float*>(TF_TensorData(result));
// (1 + 2) * (2 + 1) / 3 + 5 should be 8
EXPECT_FLOAT_EQ(output_value, 8.0);
TF_DeleteTensor(result);
TFE_DeleteTensorHandle(compute_fn_outputs[0]);
TFE_DeleteTensorHandle(input_a);
TFE_DeleteTensorHandle(input_b);
TFE_DeleteOp(compute_fn_op);
TF_DeleteSavedModel(saved_model);
TF_DeleteStatus(status);
TFE_DeleteContext(ctx);
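The test above is also the clearest end-to-end illustration of the new C API surface. As a standalone sketch (the wrapper function, its error handling, and the assumption of a concrete function named "compute" are ours, not part of this change), the same call sequence outside the test harness might look like:

// Hypothetical usage sketch built only from the C API calls exercised in the
// test above. Assumes the caller owns `ctx`, `a`, `b`, and `status`, and that
// the SavedModel exposes a concrete function named "compute".
#include <vector>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"

// Returns the first output of "compute", or nullptr if any step failed.
TFE_TensorHandle* RunCompute(const char* saved_model_dir, TFE_Context* ctx,
                             TFE_TensorHandle* a, TFE_TensorHandle* b,
                             TF_Status* status) {
  TF_SavedModel* model = TF_LoadSavedModel(saved_model_dir, ctx, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TF_ConcreteFunction* fn =
      TF_GetSavedModelConcreteFunction(model, "compute", status);
  TFE_Op* op = (TF_GetCode(status) == TF_OK)
                   ? TF_ConcreteFunctionGetCallOp(fn, status)
                   : nullptr;
  TFE_TensorHandle* output = nullptr;
  if (TF_GetCode(status) == TF_OK) {
    // User inputs come first, followed by the function's captured handles.
    const TF_TensorHandleList* captures = TF_ConcreteFunctionGetCaptures(fn);
    std::vector<TFE_TensorHandle*> inputs = {a, b};
    for (int i = 0; i < TF_TensorHandleListSize(captures); ++i) {
      inputs.push_back(TF_TensorHandleListGet(captures, i));
    }
    TFE_OpAddInputList(op, inputs.data(), static_cast<int>(inputs.size()),
                       status);
    int num_retvals = 1;
    if (TF_GetCode(status) == TF_OK) {
      TFE_Execute(op, &output, &num_retvals, status);
    }
  }
  if (op != nullptr) TFE_DeleteOp(op);
  TF_DeleteSavedModel(model);
  return TF_GetCode(status) == TF_OK ? output : nullptr;
}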

View File

@ -86,11 +86,7 @@ TEST_P(CPPSavedModelAPITest, LoadsSavedModel) {
std::unique_ptr<SavedModelAPI> model =
SavedModelAPI::Load(model_dir, *runtime, &status);
// TODO(bmzhao): Change this to expect TF_OK when loading is implemented.
// That unblocks writing other tests that require a TF_SavedModel*,
// like loading a ConcreteFunction. This test at least checks that the
// C API builds and can be minimally run.
EXPECT_EQ(status.code(), TF_UNIMPLEMENTED) << status.message();
EXPECT_EQ(status.code(), TF_OK) << status.message();
}
INSTANTIATE_TEST_SUITE_P(RuntimeAgnosticCPPSavedModelTests,

View File

@ -648,11 +648,11 @@ cc_library(
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:bounds_check",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
@ -677,11 +677,11 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_bounds_check",
"//tensorflow/core:framework_internal",
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/framework:bounds_check",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",

View File

@ -277,7 +277,8 @@ static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
absl::Span<VariableInfo const> variable_infos,
absl::Span<const int> constants, bool lazy, xla::LocalClient** client,
absl::Span<const int> constants, bool lazy, bool may_alias_resource_update,
xla::LocalClient** client,
const XlaCompiler::CompilationResult** compilation_result,
xla::LocalExecutable** executable) {
// We store information about the JIT-compiled XLA computation
@ -332,6 +333,9 @@ static Status CompileToLocalExecutable(
// Optimization: where possible, have the computation return a naked array
// rather than a one-element tuple.
compile_options.always_return_tuple = false;
compile_options.alias_resource_update = !has_ref_vars &&
!platform_info.is_on_xla_device() &&
may_alias_resource_update;
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
@ -350,20 +354,22 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* compilation_result;
xla::LocalExecutable* executable;
ResourceVarsSnapshot variables_snapshot;
std::vector<VariableInfo> variable_infos;
{
std::vector<VariableInfo> variable_infos;
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
Status s = CompileToLocalExecutable(
ctx, function_, /*has_ref_vars=*/has_ref_vars_, platform_info_,
variable_infos, constants_, /*lazy=*/false, &client,
&compilation_result, &executable);
variable_infos, constants_, /*lazy=*/false,
/*may_alias_resource_update=*/true, &client, &compilation_result,
&executable);
OP_REQUIRES_OK(ctx, s);
OP_REQUIRES_OK(ctx,
SnapshotResourceVariables(ctx, resources_, variable_infos,
&variables_snapshot));
}
std::map<int, const Tensor*> resource_var_ptrs;
for (int i = 0; i < resources_.size(); i++) {
resource_var_ptrs[resources_[i]] = variable_infos[i].var()->tensor();
}
se::Stream* stream =
@ -374,12 +380,19 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
int device_ordinal = stream ? stream->parent()->device_ordinal()
: client->default_device_ordinal();
XlaComputationLaunchContext launch_context(
client, allocator,
client, allocator, device_ordinal,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
platform_info_.UseMultipleStreams());
launch_context.PopulateInputs(ctx, compilation_result, variables_snapshot,
/*missing_ctx_input_prefix=*/0);
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(ctx, compilation_result, resource_var_ptrs,
/*missing_ctx_input_prefix=*/0,
input_output_alias);
OP_REQUIRES_OK(ctx, execution_inputs.status());
// Execute the computation.
VLOG(2) << "Executing computation.";
@ -403,24 +416,24 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
Env* env = Env::Default();
auto start_time = env->NowMicros();
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
xla::StatusOr<xla::ExecutionOutput> execution_output;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result = executable->Run(launch_context.arguments(), run_options);
execution_output =
executable->Run(std::move(*execution_inputs), run_options);
} else {
run_result = executable->RunAsync(launch_context.arguments(), run_options);
execution_output =
executable->RunAsync(std::move(*execution_inputs), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
OP_REQUIRES_OK(
ctx, launch_context.PopulateOutputs(
ctx, compilation_result, execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(variable_infos),
input_output_alias, resource_var_ptrs));
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
OP_REQUIRES_OK(ctx,
launch_context.PopulateOutputs(
ctx, compilation_result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias,
variables_snapshot));
VLOG(1) << "Done";
}
@ -516,10 +529,14 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
OP_REQUIRES_OK(
ctx, GetVariableInfosFromCtxInputs(ctx, resources_, &variable_infos));
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(variable_infos)));
// Do not alias resource updates as locking variables in XlaCompile and
// unlocking them in XlaRun may lead to deadlocks.
Status status = CompileToLocalExecutable(
ctx, function_, has_ref_vars_, platform_info_, variable_infos,
constants_,
/*lazy=*/!must_compile_, &client, &kernel, &executable);
/*lazy=*/!must_compile_,
/*may_alias_resource_update=*/false, &client, &kernel, &executable);
OP_REQUIRES_OK(ctx, SnapshotResourceVariables(ctx, resources_,
variable_infos, &variables));
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
@ -587,14 +604,22 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
int device_ordinal = stream ? stream->parent()->device_ordinal()
: closure.client()->default_device_ordinal();
XlaComputationLaunchContext launch_context(
closure.client(), allocator,
closure.client(), allocator, device_ordinal,
/*allocate_xla_tensors=*/platform_info_.is_on_xla_device(),
/*use_multiple_streams=*/platform_info_.UseMultipleStreams());
// We're missing the must-be-constant inputs, tell `PopulateInputs`
// about this. We don't actually need these inputs because they've
// already been baked into the compiled kernel.
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs;
std::map<int, const Tensor*> snapshot_ptrs;
{
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
@ -604,13 +629,17 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
},
tensorflow::profiler::TraceMeLevel::kInfo);
launch_context.PopulateInputs(
ctx, closure.compilation_result(), closure.resource_var_snapshots(),
/*missing_ctx_input_prefix=*/closure.num_constant_args());
for (auto& p : closure.resource_var_snapshots()) {
snapshot_ptrs.emplace(p.first,
p.second.has_value() ? &p.second.value() : nullptr);
}
execution_inputs = launch_context.PopulateInputs(
ctx, closure.compilation_result(), snapshot_ptrs,
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias);
OP_REQUIRES_OK(ctx, execution_inputs.status());
}
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(allocator);
@ -631,21 +660,19 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
Env* env = Env::Default();
auto start_time = env->NowMicros();
xla::StatusOr<xla::ScopedShapedBuffer> run_result;
xla::StatusOr<xla::ExecutionOutput> execution_output;
if (!stream || platform_info_.platform_id() == se::host::kHostPlatformId) {
run_result =
closure.executable()->Run(launch_context.arguments(), run_options);
execution_output =
closure.executable()->Run(std::move(*execution_inputs), run_options);
} else {
run_result =
closure.executable()->RunAsync(launch_context.arguments(), run_options);
execution_output = closure.executable()->RunAsync(
std::move(*execution_inputs), run_options);
}
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
OP_REQUIRES(ctx, execution_output.ok(), execution_output.status());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
const xla::HloInputOutputAliasConfig& input_output_alias =
closure.executable()->executable()->module().input_output_alias_config();
tensorflow::profiler::TraceMe hlo_module_activity(
[&] {
@ -653,12 +680,16 @@ void XlaRunOp::Compute(OpKernelContext* ctx) {
},
tensorflow::profiler::TraceMeLevel::kInfo);
xla::StatusOr<std::vector<VariableInfo>> variable_infos = GatherVariableInfo(
ctx, *closure.compilation_result(), closure.num_constant_args());
OP_REQUIRES_OK(ctx, variable_infos.status());
OP_REQUIRES_OK(ctx, LockVariables(absl::MakeSpan(*variable_infos)));
OP_REQUIRES_OK(
ctx,
launch_context.PopulateOutputs(
ctx, closure.compilation_result(), run_result.ConsumeValueOrDie(),
ctx, closure.compilation_result(), execution_output->ConsumeResult(),
/*missing_ctx_input_prefix=*/closure.num_constant_args(),
input_output_alias, closure.resource_var_snapshots()));
absl::MakeSpan(*variable_infos), input_output_alias, snapshot_ptrs));
}
XlaMergeOp::XlaMergeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
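Across the three kernels above, whether XLA may alias resource-variable updates with their input buffers reduces to one predicate; the helper below is our own restatement for clarity, not a function added by this change.

// Hedged restatement of the aliasing decision made in CompileToLocalExecutable
// above. XlaLocalLaunchBase passes may_alias_resource_update=true, while
// XlaCompileOp passes false because it locks variables that are only unlocked
// later in XlaRunOp, and aliasing across that boundary could deadlock.
static bool ShouldAliasResourceUpdate(bool has_ref_vars,
                                      bool is_on_xla_device,
                                      bool may_alias_resource_update) {
  return !has_ref_vars && !is_on_xla_device && may_alias_resource_update;
}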

View File

@ -1829,7 +1829,7 @@ TEST(XlaCompilationTest, XLALiteAllowlist) {
}
EXPECT_TRUE(unknow_op.empty())
<< "Someone added support for a new TF opeations inside XLA. They must "
"be included in the XLALite allowlist or blacklist:\n"
"be included in the XLALite allowlist or denylist:\n"
<< absl::StrJoin(unknow_op, "\n");
}
} // namespace

View File

@ -50,35 +50,47 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
// Builds an XLA allocator for the device.
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
launch_context.PopulateInputs(ctx, result, variable_args,
/*missing_ctx_input_prefix=*/0);
std::map<int, const Tensor*> snapshot_ptrs;
for (auto& p : variable_args) {
snapshot_ptrs.emplace(p.first,
p.second.has_value() ? &p.second.value() : nullptr);
}
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::StatusOr<std::vector<xla::ExecutionInput>> execution_inputs =
launch_context.PopulateInputs(ctx, result, snapshot_ptrs,
/*missing_ctx_input_prefix=*/0,
input_output_alias);
TF_RETURN_IF_ERROR(execution_inputs.status());
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
VLOG(2) << "Executing computation: " << name();
for (const xla::ShapedBuffer* arg : launch_context.arguments()) {
VLOG(2) << name() << ": " << *arg;
}
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
xla::StatusOr<xla::ScopedShapedBuffer> run_result =
executable->Run(launch_context.arguments(), run_options);
xla::StatusOr<xla::ExecutionOutput> run_result =
executable->Run(execution_inputs.ConsumeValueOrDie(), run_options);
TF_RETURN_IF_ERROR(run_result.status());
const xla::HloInputOutputAliasConfig& input_output_alias =
executable->executable()->module().input_output_alias_config();
xla::ExecutionOutput execution_output = run_result.ConsumeValueOrDie();
xla::StatusOr<std::vector<VariableInfo>> variable_infos =
GatherVariableInfo(ctx, *result, 0);
TF_RETURN_IF_ERROR(variable_infos.status());
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(*variable_infos)));
TF_RETURN_IF_ERROR(launch_context.PopulateOutputs(
ctx, result, run_result.ConsumeValueOrDie(),
/*missing_ctx_input_prefix=*/0, input_output_alias, variable_args));
ctx, result, execution_output.ConsumeResult(),
/*missing_ctx_input_prefix=*/0, absl::MakeSpan(*variable_infos),
input_output_alias, snapshot_ptrs));
return Status::OK();
}

View File

@ -59,11 +59,13 @@ void XlaAssignVariableOp::Compute(OpKernelContext* context) {
return Status::OK();
}));
mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
OP_REQUIRES(
context,
!variable->is_initialized || variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
variable->is_initialized = true;
*variable->tensor() = value;
}
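The relaxed precondition above lets an uninitialized variable accept its first assignment regardless of the placeholder dtype it was created with, while an initialized variable must still match. Restated as a standalone check (ours, for illustration only):

// Hedged restatement of the assignment check above: variables created
// uninitialized (e.g. as Var(DT_INVALID) elsewhere in this change) accept any
// dtype on their first write; initialized variables must keep their dtype.
static bool AssignDtypeOk(bool is_initialized, DataType current_dtype,
                          DataType assigned_dtype) {
  return !is_initialized || current_dtype == assigned_dtype;
}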

View File

@ -91,29 +91,19 @@ VariableInfo::~VariableInfo() {
Status GetVariableInfosFromCtxInputs(OpKernelContext* ctx,
absl::Span<const int> variable_indices,
std::vector<VariableInfo>* result) {
std::vector<const ResourceHandle*> resource_handles;
absl::c_transform(
variable_indices, std::back_inserter(resource_handles),
[&](int variable_idx) { return &HandleFromInput(ctx, variable_idx); });
std::vector<core::RefCountPtr<Var>> variables;
Status s = LookupResources(ctx, resource_handles, &variables);
if (!s.ok()) {
errors::AppendToMessage(&s, kPossibleNonVariableResourceHintMessage);
return s;
}
result->clear();
result->reserve(variable_indices.size());
for (int i = 0; i < variable_indices.size(); i++) {
// *Release* the variable because we're going to unref it later in
// ~VariableInfo.
Var* variable = variables[i].release();
int input_idx = variable_indices[i];
std::string var_name = HandleFromInput(ctx, input_idx).name();
result->emplace_back(input_idx, var_name, variable);
for (int var_idx : variable_indices) {
Var* variable = nullptr;
ResourceHandle handle = HandleFromInput(ctx, var_idx);
TF_RETURN_IF_ERROR(
LookupOrCreateResource<Var>(ctx, handle, &variable, [&](Var** ptr) {
// This var is uninitialized for now.
*ptr = new Var(DT_INVALID);
return Status::OK();
}));
result->emplace_back(var_idx, handle.name(), variable);
}
return Status::OK();
}
@ -176,24 +166,43 @@ Status SnapshotResourceVariables(OpKernelContext* ctx,
XlaComputationLaunchContext::XlaComputationLaunchContext(
xla::LocalClient* client, se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors, bool use_multiple_streams)
int device_ordinal, bool allocate_xla_tensors, bool use_multiple_streams)
: client_(client),
xla_allocator_(xla_allocator),
allocate_xla_tensors_(allocate_xla_tensors),
use_multiple_streams_(use_multiple_streams) {
use_multiple_streams_(use_multiple_streams),
device_ordinal_(device_ordinal) {
if (use_multiple_streams_) {
CHECK(allocate_xla_tensors_) << "To use multiple streams correctly we must "
"be allocating XLA tensors!";
}
}
void XlaComputationLaunchContext::PopulateInputs(
// Fills in `execution_input` with `buffer` for `index`.
static void PopulateExecutionInputBuffer(xla::ExecutionInput& execution_input,
xla::ShapeIndex index,
se::DeviceMemoryBase& buffer,
bool donate_buffer, int device_ordinal,
se::DeviceMemoryAllocator* allocator) {
xla::MaybeOwningDeviceMemory* in_buffer =
execution_input.MutableBuffer(index);
if (donate_buffer) {
*in_buffer = se::OwningDeviceMemory(buffer, device_ordinal, allocator);
buffer = se::DeviceMemoryBase();
} else {
*in_buffer = buffer;
}
}
xla::StatusOr<std::vector<xla::ExecutionInput>>
XlaComputationLaunchContext::PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables, int missing_ctx_input_prefix) {
// Build ShapedBuffers that point directly to the Tensor buffers.
arg_ptrs_ =
std::vector<ShapedBuffer*>(compilation_result->xla_input_shapes.size());
const std::map<int, const Tensor*>& resource_vars,
int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias) {
std::vector<xla::ExecutionInput> arguments;
arguments.reserve(compilation_result->xla_input_shapes.size());
xla::TransferManager* transfer_manager =
client_->backend().transfer_manager();
@ -201,10 +210,28 @@ void XlaComputationLaunchContext::PopulateInputs(
int arg_num = compilation_result->input_mapping[i];
CHECK_GE(arg_num, missing_ctx_input_prefix);
const xla::Shape& shape = compilation_result->xla_input_shapes[i];
const Tensor* t = variables.count(arg_num)
? &(variables.at(arg_num).value())
const xla::Shape& device_shape =
transfer_manager->HostShapeToDeviceShape(shape);
bool is_resource_variable = resource_vars.count(arg_num);
bool is_updated_resource_variable =
is_resource_variable &&
absl::c_any_of(compilation_result->resource_updates,
[&](const XlaCompiler::ResourceUpdate& update) {
return update.input_index == i && update.modified;
});
const Tensor* t = is_resource_variable
? resource_vars.at(arg_num)
: &(ctx->input(arg_num - missing_ctx_input_prefix));
CHECK(t);
bool donate_buffer =
t->RefCountIsOne() && is_updated_resource_variable &&
input_output_alias.ParameterHasAlias(i, xla::ShapeIndex{});
VLOG(3) << "Processing input: " << i
<< "; is_resource_variable=" << is_resource_variable
<< "; is_updated_resource_variable=" << is_updated_resource_variable
<< "; donate_buffer=" << donate_buffer;
if (use_multiple_streams_) {
CHECK(ctx->op_device_context() && ctx->op_device_context()->stream())
@ -215,23 +242,28 @@ void XlaComputationLaunchContext::PopulateInputs(
ctx->op_device_context()->stream());
}
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(
shape, transfer_manager->HostShapeToDeviceShape(shape))) {
arguments.emplace_back(device_shape, shape);
xla::ExecutionInput& execution_input = arguments.back();
if (xla::Shape::Equal().MinorToMajorOnlyInLayout()(shape, device_shape)) {
se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_.emplace_back(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
arg_buffers_.back().set_buffer(dmem, /*index=*/{});
arg_ptrs_[i] = &arg_buffers_.back();
PopulateExecutionInputBuffer(execution_input, xla::ShapeIndex{}, dmem,
donate_buffer, device_ordinal_,
xla_allocator_);
} else {
const XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
XlaTensor* xla_tensor = XlaTensor::FromTensor(t);
CHECK(xla_tensor && xla_tensor->has_shaped_buffer());
arg_ptrs_[i] = const_cast<ShapedBuffer*>(&xla_tensor->shaped_buffer());
xla_tensor->shaped_buffer().buffers().ForEachMutableElement(
[&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
PopulateExecutionInputBuffer(execution_input, index, *buffer,
donate_buffer, device_ordinal_,
xla_allocator_);
});
}
}
return std::move(arguments);
}
// Construct the tensor for given type and buffer.
// Construct the tensor for the given type and buffer.
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
se::DeviceMemoryBase buffer, Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
@ -247,28 +279,26 @@ static Tensor GetOrCreateTensorForOutput(
int output_num, OpKernelContext* ctx, int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias,
absl::Span<const int> input_mapping,
const ResourceVarsSnapshot& resource_var_snapshots, DataType output_dtype,
const TensorShape& output_shape, se::DeviceMemoryBase output_buffer,
Allocator* output_allocator) {
const std::map<int, const Tensor*>& resource_vars_snapshots,
DataType output_dtype, const TensorShape& output_shape,
se::DeviceMemoryBase output_buffer, Allocator* output_allocator) {
xla::ShapeIndex output_index = input_output_alias.shape().IsTuple()
? xla::ShapeIndex({output_num})
: xla::ShapeIndex({});
CHECK(input_output_alias.shape().IsTuple() || output_num == 0);
if (absl::optional<xla::HloInputOutputAliasConfig::Alias> alias =
input_output_alias.GetAliasedParameter(output_index)) {
VLOG(3) << "Found alias: " << alias->ToString();
int tf_param =
input_mapping[alias->parameter_number] - missing_ctx_input_prefix;
const Tensor* input_tensor = &ctx->input(tf_param);
// If input tensor is a resource variable, alias to the snapshot we took at
// entry time.
if (input_tensor->dtype() == DT_RESOURCE) {
const absl::optional<Tensor>& v =
resource_var_snapshots.at(missing_ctx_input_prefix + tf_param);
CHECK(v.has_value());
return *v;
const Tensor input_tensor =
ctx->input(tf_param).dtype() != DT_RESOURCE
? ctx->input(tf_param)
: *resource_vars_snapshots.at(missing_ctx_input_prefix + tf_param);
if (output_buffer.opaque() == input_tensor.data()) {
return input_tensor;
}
return *input_tensor;
}
return MakeTensor(output_dtype, output_shape, output_buffer,
output_allocator);
@ -291,12 +321,10 @@ static Status SetOutputForConstant(
OpKernelContext* ctx, se::Stream* stream,
const XlaCompiler::CompilationResult* compilation_result, int output_num) {
CHECK(compilation_result->outputs[output_num].is_constant);
// Output is a constant.
const Tensor& const_tensor =
compilation_result->outputs[output_num].constant_value;
Tensor* output_tensor;
const size_t total_bytes = const_tensor.TotalBytes();
if (stream && total_bytes > 0) {
if (stream && const_tensor.TotalBytes() > 0) {
// Copy host -> device. (Empty tensors don't have backing buffers.)
// Manually allocate memory using an XlaTensorBuffer so we can allocate
// as much memory as the device requires (as given by
@ -335,52 +363,55 @@ static Status SetOutputForConstant(
return Status::OK();
}
// Creates a list of updated resource variables.
static xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> variable_infos;
variable_infos.reserve(compilation_result->resource_updates.size());
static xla::StatusOr<Var*> GetOrCreateResourceVar(
OpKernelContext* ctx, const ResourceHandle& handle,
const XlaCompiler::ResourceUpdate& write) {
Var* variable = nullptr;
TF_RETURN_IF_ERROR(
LookupOrCreateResource<Var>(ctx, handle, &variable, [&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
return variable;
}
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult& compilation_result,
int missing_ctx_input_prefix) {
std::vector<VariableInfo> out;
out.reserve(compilation_result.resource_updates.size());
for (int i = 0; i < compilation_result.resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
compilation_result.resource_updates[i];
int actual_input_index = write.input_index - missing_ctx_input_prefix;
if (actual_input_index < 0 || actual_input_index >= ctx->num_inputs()) {
return errors::Internal("Invalid input index for variable write.");
}
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
// not a Tensor.
Var* variable = nullptr;
const ResourceHandle handle = HandleFromInput(ctx, actual_input_index);
TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(ctx, handle, &variable,
[&write](Var** ptr) {
*ptr = new Var(write.type);
return Status::OK();
}));
variable_infos.emplace_back(actual_input_index, handle.name(), variable);
TF_ASSIGN_OR_RETURN(Var * variable,
GetOrCreateResourceVar(ctx, handle, write));
out.emplace_back(actual_input_index, handle.name(), variable);
}
return variable_infos;
return std::move(out);
}
Status XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
ScopedShapedBuffer output, int missing_ctx_input_prefix,
absl::Span<VariableInfo> variable_infos,
const xla::HloInputOutputAliasConfig& input_output_alias,
const ResourceVarsSnapshot& resource_var_snapshots) {
const std::map<int, const Tensor*>& resource_vars) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
Allocator* allocator = ctx->device()->GetAllocator({});
// Computation output should always be a tuple.
if (VLOG_IS_ON(2)) {
VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
}
VLOG(2) << "Result tuple shape: " << output.on_host_shape().DebugString();
VLOG(2) << "Result tuple shape (on device): "
<< output.on_device_shape().DebugString();
CHECK_EQ(ctx->num_outputs(), compilation_result->outputs.size());
// If the on-host-shape isn't a tuple, create a new single-element tuple
@ -438,8 +469,8 @@ Status XlaComputationLaunchContext::PopulateOutputs(
for (int i = 0; i < ctx->num_outputs(); ++i) {
const TensorShape& shape = output_tensor_shapes[i];
const DataType& type = compilation_result->outputs[i].type;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
<< DataTypeString(type);
VLOG(2) << "Populating output for retval " << i << " shape "
<< shape.DebugString() << " type " << DataTypeString(type);
if (type == DT_VARIANT) {
return errors::Unimplemented(
"Support for TensorList crossing the XLA/TF boundary "
@ -467,30 +498,37 @@ Status XlaComputationLaunchContext::PopulateOutputs(
se::DeviceMemoryBase buffer = output.buffer({output_num});
Tensor output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots,
compilation_result->input_mapping, resource_vars,
ctx->expected_output_dtype(i), shape, buffer, allocator);
output.set_buffer(se::OwningDeviceMemory(), {output_num});
ctx->set_output(i, output_tensor);
}
output.set_buffer(se::OwningDeviceMemory(), {output_num});
++output_num;
}
if (VLOG_IS_ON(3)) {
VLOG(3) << ctx->mutable_output(i)->DeviceSafeDebugString();
}
}
// Apply variable updates, if any.
VLOG(2) << "Applying variable updates";
TF_ASSIGN_OR_RETURN(
std::vector<VariableInfo> variable_infos,
GatherVariableInfo(ctx, compilation_result, missing_ctx_input_prefix));
TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));
// input_index -> index into variable_infos.
absl::flat_hash_map<int, int> variable_info_lookup;
for (int i = 0; i < variable_infos.size(); i++) {
variable_info_lookup.emplace(variable_infos[i].index(), i);
}
// Apply variable updates, if any.
for (int i = 0; i < compilation_result->resource_updates.size(); ++i) {
const XlaCompiler::ResourceUpdate& write =
compilation_result->resource_updates[i];
if (variable_infos[i].var()->tensor()->dtype() != write.type) {
int actual_input_index = write.input_index - missing_ctx_input_prefix;
CHECK_GE(actual_input_index, 0);
CHECK_LT(actual_input_index, ctx->num_inputs());
Var* var = variable_infos[variable_info_lookup[actual_input_index]].var();
CHECK(var);
VLOG(2) << "Updating variable #" << i
<< " at input index: " << actual_input_index << " with shape "
<< write.shape.DebugString() << "; variable tensor has shape: "
<< var->tensor()->shape().DebugString();
if (var->is_initialized && var->tensor()->dtype() != write.type) {
return errors::Internal("Mismatched type in variable write");
}
@ -504,14 +542,14 @@ Status XlaComputationLaunchContext::PopulateOutputs(
}
} else {
se::DeviceMemoryBase buffer = output.buffer({output_num});
output.set_buffer(se::OwningDeviceMemory(), {output_num});
output_tensor = GetOrCreateTensorForOutput(
output_num, ctx, missing_ctx_input_prefix, input_output_alias,
compilation_result->input_mapping, resource_var_snapshots, write.type,
compilation_result->input_mapping, resource_vars, write.type,
write.shape, buffer, allocator);
}
*variable_infos[i].var()->tensor() = output_tensor;
variable_infos[i].var()->is_initialized |= write.modified;
output.set_buffer(se::OwningDeviceMemory(), {output_num});
var->is_initialized |= write.modified;
*var->tensor() = output_tensor;
++output_num;
}
return Status::OK();
@ -562,7 +600,7 @@ Status XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.name = std::string(variable.name());
arg.kind = XlaCompiler::Argument::kResource;
arg.resource_kind = XlaResource::kVariable;
if (variable.var()) {
if (variable.var() && variable.var()->is_initialized) {
const Tensor* value = variable.var()->tensor();
arg.type = value->dtype();
arg.shape = value->shape();
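The per-argument donation decision in PopulateInputs above combines three conditions; pulling them into one helper (our own summary of the expression computed there, not code from this change) makes the intent easier to scan.

// Hedged summary of the donate_buffer expression in PopulateInputs above: a
// device buffer is donated to XLA only when nothing else references the
// tensor, the argument is a resource variable the computation modifies, and
// the compiled module declares an input/output alias for that parameter.
static bool CanDonateInputBuffer(
    const Tensor& t, bool is_updated_resource_variable, int param_index,
    const xla::HloInputOutputAliasConfig& input_output_alias) {
  return t.RefCountIsOne() && is_updated_resource_variable &&
         input_output_alias.ParameterHasAlias(param_index, xla::ShapeIndex{});
}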

View File

@ -81,6 +81,12 @@ class VariableInfo {
bool lock_held_ = false;
};
// Creates a list of updated resource variables.
xla::StatusOr<std::vector<VariableInfo>> GatherVariableInfo(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult& compilation_result,
int missing_ctx_input_prefix);
// Takes a snapshot of the values of resource variable arguments, whose indices
// are specified in `variable_indices` argument. We snapshot tensors that back
// resource variables since concurrent updates may modify the shape, and it is
@ -124,7 +130,7 @@ class XlaComputationLaunchContext {
// objects.
XlaComputationLaunchContext(xla::LocalClient* client,
se::DeviceMemoryAllocator* xla_allocator,
bool allocate_xla_tensors,
int device_ordinal, bool allocate_xla_tensors,
bool use_multiple_streams);
// Builds a XlaCompiler::Argument vector from the arguments to an XlaLaunch
@ -142,10 +148,12 @@ class XlaComputationLaunchContext {
// missing and adjusts input indices accordingly. All elements in kernel's
// input_mapping must be greater than or equal to `missing_ctx_input_prefix`
// (in other words, no inputs actually required by the kernel can be missing).
void PopulateInputs(OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const ResourceVarsSnapshot& variables,
int missing_ctx_input_prefix);
xla::StatusOr<std::vector<xla::ExecutionInput>> PopulateInputs(
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
const std::map<int, const Tensor*>& resource_vars,
int missing_ctx_input_prefix,
const xla::HloInputOutputAliasConfig& input_output_alias);
// Given the XLA output in `output`, populate all outputs of `ctx`. Also
// writes out the resource variable updates.
@ -161,20 +169,16 @@ class XlaComputationLaunchContext {
OpKernelContext* ctx,
const XlaCompiler::CompilationResult* compilation_result,
xla::ScopedShapedBuffer output, int missing_ctx_input_prefix,
absl::Span<VariableInfo> variable_infos,
const xla::HloInputOutputAliasConfig& input_output_alias,
const ResourceVarsSnapshot& resource_var_snapshots);
// Return the argument list. Only valid after PopulateInputs() has been
// called.
const std::vector<xla::ShapedBuffer*>& arguments() const { return arg_ptrs_; }
const std::map<int, const Tensor*>& resource_vars);
private:
xla::LocalClient* client_;
se::DeviceMemoryAllocator* xla_allocator_;
bool allocate_xla_tensors_;
bool use_multiple_streams_;
std::deque<xla::ShapedBuffer> arg_buffers_;
std::vector<xla::ShapedBuffer*> arg_ptrs_;
int device_ordinal_;
};
// A simple TensorBuffer implementation that allows us to create Tensors that
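With the header changes above, the launch context no longer owns the argument buffers; callers move the ExecutionInputs returned by PopulateInputs into Run/RunAsync and then pass locked variables to PopulateOutputs. A condensed sketch of the new calling sequence, mirroring the kernel updates earlier in this change (the wrapper function itself is hypothetical, and error handling beyond the TF_ macros is elided):

// Hypothetical wrapper showing the order of the new calls only; it mirrors
// XlaRunOp/XlaLocalLaunchBase above rather than adding new behavior.
static Status RunCompiledComputation(
    OpKernelContext* ctx, XlaComputationLaunchContext* launch_context,
    const XlaCompiler::CompilationResult* compilation_result,
    xla::LocalExecutable* executable,
    const std::map<int, const Tensor*>& resource_vars,
    const xla::ExecutableRunOptions& run_options) {
  const xla::HloInputOutputAliasConfig& input_output_alias =
      executable->executable()->module().input_output_alias_config();

  // 1) Build ExecutionInputs, possibly donating variable buffers.
  TF_ASSIGN_OR_RETURN(std::vector<xla::ExecutionInput> inputs,
                      launch_context->PopulateInputs(
                          ctx, compilation_result, resource_vars,
                          /*missing_ctx_input_prefix=*/0, input_output_alias));

  // 2) Execute; ownership of donated buffers moves into XLA.
  TF_ASSIGN_OR_RETURN(xla::ExecutionOutput output,
                      executable->Run(std::move(inputs), run_options));

  // 3) Gather and lock updated variables, then write back outputs/updates.
  TF_ASSIGN_OR_RETURN(std::vector<VariableInfo> variables,
                      GatherVariableInfo(ctx, *compilation_result,
                                         /*missing_ctx_input_prefix=*/0));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));
  return launch_context->PopulateOutputs(
      ctx, compilation_result, output.ConsumeResult(),
      /*missing_ctx_input_prefix=*/0, absl::MakeSpan(variables),
      input_output_alias, resource_vars);
}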

View File

@ -74,7 +74,6 @@ We have several choices on how to lower the host-side part from LHLO:
* (Pro) easy to implement library calls (cuDNN, cuBLAS, cuFFT, etc), as
TFRT ops are interpreted by C++ code.
* (Con) host side is under development and not tested.
* (Con) the JAX integration isn't clear from a runtime point of view
* Jitted CPU code
* (Pro) great lower-ability. Create a few loops and conditions and it's
done.
@ -84,8 +83,7 @@ We have several choices on how to lower the host-side part from LHLO:
dynamic loading, etc).
* Existing (interpreting) XLA runtime
Tentative conclusion: Use jitted CPU code during the transition, and optionally
adopt TFRT in the end.
Decision: adopt TFRT, but also support jitting CPU code in TFRT.
## Migrating Device LLVM IR (Task 3)
@ -114,7 +112,7 @@ end state of each XLA op:
* (Cost) Will be throw-away work if we want to ultimately migrate to
Standard.
* (Benefit) It is easy and mechanical. Can be done in a short period.
* (Benefit) It doesn't benefit more compared to a).
* (Benefit) It doesn't benefit more compared to (1).
1. Refactor old emitters to be like LHLO -> MLIR GPU + Standard + Loops:
* (Cost) Lifting existing emitters to Standard introduces some challenges.
Pointers and GEPs need to be converted to MemRefs and SubViews. Ensuring
@ -134,6 +132,19 @@ end state of each XLA op:
* (Benefit) unified stack; community support; portability; more
optimization potentials.
Conclusions:
* Don't go for (2). (1) or (3) are just better than (2). (2) costs more than
(1), since it requires a lot of mechanical refactoring. With (1) we can
still achieve the goal of enabling XLA to pick up MLIR emitters. This is by
doing LHLO -> LLVM IR -> run legacy device emitters.
* ElementalIrEmitter ops go for (4), but not incrementally. There is no way to
do it op by op, because all elementally-emitted ops are connected into the
same graph. This work can also serve as a unification point of several
on-going forces (xla/service/mlir\_gpu, the kernel generator, Linalg).
* All other ops go for (1). As a stretch goal, they might be migrated to (3)
or (4).
## Prioritization
While all three tasks mentioned above are parallelizable, under limited
@ -210,26 +221,19 @@ The exact profiling can't be easily done for MLIR-generated ops, since:
### Step 3: (Task 2) Migrating Thunks
This step migrates all host ops and library calls. This step will eliminate most
of the thunks and produce serializable MLIR instead.
There are roughly three kinds of thunks:
As a note, there are roughly three kinds of thunks:
* KernelThunk, which launches a kernel.
* Control flow thunks, which have host control flow logic (conditional, while,
for, sequence) and launch body kernels.
* Library thunks: cuDNN, cuBLAS, cuFFT, NCCL, etc.
The **bottom line** is to:
The plan is:
* Make Thunks (de)serializable.
* Help improve TFRT to a state where it can support these semantics.
* As the state improves, migrate individual thunks incrementally.
* Create a Thunk dialect that provides (de)serialize logic for all existing
C++-based Thunks.
* Change emitters to emit a graph of Thunk dialect.
**Optionally**, we can relieve some thunks from C++ implementation. KernelThunk
can lower to the GPU LaunchKernelOp. Control flow thunks can leverage the CFG
Dialect for loops and conditions, combined with LaunchKernelOp. This optional
step requires profiling and stream support.
These action items are only partially ordered; the actual execution order and the
degree of engineering parallelism will be evaluated as the work proceeds.
### Step 4: (Task 3) Migrated ElementalIrEmitter

View File

@ -106,6 +106,25 @@ gentbl(
td_srcs = [":hlo_ops_td_files"],
)
gentbl(
name = "hlo_ops_pattern_gen",
strip_include_prefix = "include",
tbl_outs = [
(
"-gen-rewriters",
"include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "lib/Dialect/mhlo/IR/hlo_patterns.td",
td_srcs = [
":hlo_ops_td_files",
"@llvm-project//mlir:StdOpsTdFiles",
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeBase.td",
"@llvm-project//mlir:include/mlir/Dialect/Shape/IR/ShapeOps.td",
],
)
gentbl(
name = "lhlo_ops_inc_gen",
strip_include_prefix = "include",
@ -203,6 +222,7 @@ cc_library(
],
includes = ["include"],
deps = [
"hlo_ops_pattern_gen",
":canonicalize_inc_gen",
":chlo_ops_inc_gen",
":convert_op_folder",
@ -548,6 +568,26 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "legalize_gather_to_torch_index_select",
srcs = ["lib/Dialect/mhlo/transforms/legalize_gather_to_torch_index_select.cc"],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
"include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h",
],
deps = [
":hlo",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:Transforms",
],
alwayslink = 1,
)
cc_library(
name = "legalize_tanh_to_approximation",
srcs = ["lib/Dialect/mhlo/transforms/legalize_tanh_to_approximation.cc"],
@ -590,6 +630,7 @@ cc_library(
"lib/Dialect/mhlo/transforms/generated_lower_complex.inc",
"lib/Dialect/mhlo/transforms/lower_complex.cc",
"lib/Dialect/mhlo/transforms/lower_general_dot.cc",
"lib/Dialect/mhlo/transforms/optimize_mhlo.cc",
],
hdrs = [
"include/mlir-hlo/Dialect/mhlo/transforms/passes.h",
@ -649,7 +690,9 @@ cc_library(
deps = [
":hlo",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
],
)
@ -661,6 +704,7 @@ cc_library(
"lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc",
"lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc",
"lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc",
"lib/Dialect/mhlo/transforms/optimize_mhlo_pass.cc",
"lib/Dialect/mhlo/transforms/test_infer_shaped_type_pass.cc",
"lib/Dialect/mhlo/transforms/unfuse_batch_norm_pass.cc",
],
@ -678,6 +722,7 @@ cc_library(
"@llvm-project//mlir:LLVMDialect",
"@llvm-project//mlir:LLVMTransforms",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:SCFDialect",
"@llvm-project//mlir:Shape",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Transforms",
@ -695,6 +740,7 @@ cc_library(
":hlo_dialect_registration",
":hlo_legalize_to_lhlo",
":legalize_control_flow",
":legalize_gather_to_torch_index_select",
":legalize_tanh_to_approximation",
":legalize_to_linalg",
":legalize_to_standard",

View File

@ -21,9 +21,9 @@ limitations under the License.
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/infer_fusibility_op_interface.td"
def HLO_Dialect : Dialect {
let name = "mhlo";

View File

@ -38,6 +38,13 @@ void PopulateGeneralDotOpLoweringPatterns(OwningRewritePatternList *patterns,
void PopulateComplexLoweringPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
void PopulateOptimizeMHLOPatterns(MLIRContext *context,
OwningRewritePatternList *patterns);
// Rewrite patterns for gather to equivalent torch index select legalization.
void PopulateGatherToTorchIndexSelectPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns);
void PopulateMhloToStdPatterns(OwningRewritePatternList *patterns,
MLIRContext *ctx);

View File

@ -35,6 +35,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
@ -59,6 +60,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
namespace mlir {
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_patterns.cc.inc"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc"
namespace mhlo {
@ -744,7 +746,8 @@ class DynamicBroadcastInDimOpNotActuallyDynamic
void DynamicBroadcastInDimOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic>(context);
results.insert<DynamicBroadcastInDimOpNotActuallyDynamic,
DynamicBroadcastToOwnShape>(context);
}
//===----------------------------------------------------------------------===//
@ -1465,7 +1468,7 @@ static LogicalResult Verify(PadOp op) {
static LogicalResult Verify(ReshapeOp op) {
// If the operand type is dynamically shaped there is nothing to verify.
auto operand_ty = op.operand().getType().cast<RankedTensorType>();
auto operand_ty = op.operand().getType().dyn_cast<RankedTensorType>();
if (!operand_ty || !operand_ty.hasStaticShape()) return success();
// If the operand type is statically shaped (not required) the number of

View File

@ -0,0 +1,29 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Canonicalization patterns for the MHLO dialect.
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
def EqualBinaryOperands : Constraint<CPred<"$0 == $1">>;
// Canonicalization patterns.
def DynamicBroadcastToOwnShape : Pat<
(HLO_DynamicBroadcastInDimOp:$op $arg0,
(Shape_ToExtentTensorOp (Shape_ShapeOfOp $arg1)), $attr),
(replaceWithValue $arg0), [(EqualBinaryOperands $arg0, $arg1)]>;

View File

@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/OperationSupport.h" // from @llvm-project
@ -22,6 +24,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h"
namespace mlir {
@ -74,10 +77,6 @@ struct ConvertTrivialNonBroadcastBinaryOp : public OpRewritePattern<ChloOpTy> {
// - Legal combinations of degenerate (1-dim) implicit broadcasting.
// The restriction on broadcast_dims derives from the definition of the
// `shape.broadcast` op, which only supports prefix-padding.
//
// It may be possible to expand this pattern to operate on unranked tensors in
// the future by emitting more code to dynamically differentiate based on rank.
// Whether that is of any practical benefit remains to be seen.
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
struct ConvertRankedDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
@ -160,6 +159,68 @@ struct ConvertRankedDynamicBroadcastBinaryOp
}
};
// Converts a broadcasting binary operation with a scalar operand and an
// unranked operand to a ranked broadcasting operation by dynamically reshaping
// the unranked operand to a 1D tensor. This will always be safe because
// broadcasting from a scalar to another shape always works.
template <typename ChloOpTy, typename HloOpTy>
struct ConvertUnrankedScalarDynamicBroadcastBinaryOp
: public OpRewritePattern<ChloOpTy> {
using OpRewritePattern<ChloOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(ChloOpTy op,
PatternRewriter &rewriter) const override {
auto loc = op.getLoc();
Value lhs = op.lhs();
Value rhs = op.rhs();
auto lhs_ranked_type = lhs.getType().dyn_cast<RankedTensorType>();
auto lhs_unranked_type = lhs.getType().dyn_cast<UnrankedTensorType>();
auto rhs_ranked_type = rhs.getType().dyn_cast<RankedTensorType>();
auto rhs_unranked_type = rhs.getType().dyn_cast<UnrankedTensorType>();
bool lhs_is_scalar = lhs_ranked_type &&
lhs_ranked_type.getShape().empty() &&
rhs_unranked_type;
bool rhs_is_scalar = rhs_ranked_type &&
rhs_ranked_type.getShape().empty() &&
lhs_unranked_type;
// Only support the case where exactly one operand is scalar and the other
// is unranked. Other patterns in this file will create more efficient
// lowerings for cases where both ranks are known or will handle the more
// generic case of both inputs being unranked.
if (!(lhs_is_scalar ^ rhs_is_scalar)) return failure();
auto result_type = op.getResult().getType().template dyn_cast<TensorType>();
// Reshape the non-scalar value into a dynamically sized, rank-1 tensor
Value shape =
rewriter.create<shape::ShapeOfOp>(loc, lhs_is_scalar ? rhs : lhs);
Value num_elements = rewriter.create<shape::NumElementsOp>(loc, shape);
Value size = rewriter.create<shape::SizeToIndexOp>(loc, num_elements);
Value size_tensor = rewriter.create<TensorFromElementsOp>(loc, size);
Value reshaped = rewriter.create<mhlo::DynamicReshapeOp>(
loc, RankedTensorType::get({-1}, result_type.getElementType()),
lhs_is_scalar ? rhs : lhs, size_tensor);
// Create a new ranked Chlo op that will be further lowered by other
// patterns into Mhlo.
SmallVector<Value, 2> operands{lhs_is_scalar ? lhs : reshaped,
rhs_is_scalar ? rhs : reshaped};
Value computed = rewriter.create<ChloOpTy>(
loc, SmallVector<Type, 1>{reshaped.getType()}, operands, op.getAttrs());
// Reshape the result back into an unranked tensor.
Value shape_tensor = rewriter.create<shape::ToExtentTensorOp>(
loc, RankedTensorType::get({-1}, rewriter.getIndexType()), shape);
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_type,
computed, shape_tensor);
return success();
}
};
template <typename ChloOpTy, typename HloOpTy, typename Adaptor>
void PopulateForBinaryOp(MLIRContext *context,
OwningRewritePatternList *patterns) {
@ -169,6 +230,9 @@ void PopulateForBinaryOp(MLIRContext *context,
patterns->insert<
ConvertRankedDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy, Adaptor>>(
context, 5);
patterns->insert<
ConvertUnrankedScalarDynamicBroadcastBinaryOp<ChloOpTy, HloOpTy>>(
context);
}
template <typename FromOpTy, typename ToOpTy>

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project
#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
@ -37,6 +38,7 @@ struct TestChloLegalizeToHloPass
// The conversion uses helpers from the Standard dialect.
conversionTarget.addLegalDialect<mlir::StandardOpsDialect>();
conversionTarget.addLegalDialect<mlir::shape::ShapeDialect>();
conversionTarget.addLegalDialect<mlir::scf::SCFDialect>();
PopulateLegalizeChloToHloPatterns(&getContext(), &conversionPatterns);

View File

@ -42,9 +42,6 @@ namespace {
template <typename T>
using BaseOpConversion = BufferAssignmentOpConversionPattern<T>;
using StdReturnOpConverter =
detail::BufferAssignmentReturnOpConverter<mlir::ReturnOp, mlir::ReturnOp,
lmhlo::CopyOp, true>;
Value InsertDynamicAllocAndDealloc(Location loc, Value result,
Value shape_operand,
@ -272,27 +269,21 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
// Copy over the operations inside the region.
rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end());
// Create new block arguments with correct type.
// Convert the region signature to memref and add extra result.
auto& entry_block = new_op.body().front();
int original_arg_count = entry_block.getNumArguments();
for (int i = 0; i < original_arg_count; ++i) {
auto old_arg = entry_block.getArgument(i);
auto old_type = old_arg.getType().cast<TensorType>();
TypeConverter::SignatureConversion sig_conversion(
entry_block.getNumArguments() + 1);
for (auto arg : entry_block.getArguments()) {
auto old_type = arg.getType().cast<TensorType>();
auto new_type =
MemRefType::get(old_type.getShape(), old_type.getElementType());
auto new_arg = entry_block.addArgument(new_type);
rewriter.replaceUsesOfBlockArgument(old_arg, new_arg);
sig_conversion.addInputs(arg.getArgNumber(), new_type);
}
// Add an argument for the result.
entry_block.addArgument(
entry_block.getArgument(original_arg_count).getType());
// Remove the old arguments.
for (int i = original_arg_count - 1; i >= 0; --i) {
entry_block.eraseArgument(i);
}
// Insert terminator at the end.
rewriter.setInsertionPointToEnd(&entry_block);
rewriter.create<lmhlo::TerminatorOp>(loc);
auto return_op = cast<mhlo::ReturnOp>(entry_block.getTerminator());
auto result_type = return_op.results().front().getType().cast<TensorType>();
sig_conversion.addInputs({MemRefType::get(result_type.getShape(),
result_type.getElementType())});
rewriter.applySignatureConversion(&new_op.body(), sig_conversion);
rewriter.replaceOp(op, ArrayRef<Value>(buffer_args).slice(operands.size()));
@ -300,6 +291,12 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion<mhlo::ReduceOp> {
}
};
// Legalize mhlo.return to a lmhlo.copy and lmhlo.terminator. This functionality
// is provided by mlir buffer assignment, so use the pattern from there.
// TODO(DFKI): Move this out of detail.
using HloToLhloReturnOpConverter = detail::BufferAssignmentReturnOpConverter<
mhlo::ReturnOp, lmhlo::TerminatorOp, lmhlo::CopyOp, false>;
class HloToLhloTensorLoadOpConverter
: public BaseOpConversion<mlir::TensorLoadOp> {
public:
@ -312,7 +309,6 @@ class HloToLhloTensorLoadOpConverter
}
};
// TODO(b/137624192): Rewrite into a copy and elide copy if possible.
class HloToLhloTensorStoreOpConverter
: public BaseOpConversion<mlir::TensorStoreOp> {
public:
@ -506,6 +502,7 @@ void populateHLOToLHLOConversionPattern(
HloToLhloOpConverter<mhlo::SubOp>,
HloToLhloOpConverter<mhlo::TanhOp>,
HloToLhloReduceOpConverter,
HloToLhloReturnOpConverter,
HloToLhloTensorLoadOpConverter,
HloToLhloTensorStoreOpConverter
>(context, bufferAssignment, converter);

View File

@ -0,0 +1,152 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/memory/memory.h"
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
namespace mlir {
namespace mhlo {
namespace {
struct GatherIsTorchIndexSelect : public OpRewritePattern<GatherOp> {
using OpRewritePattern<GatherOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter &rewriter) const override {
auto start_indices = gather.start_indices();
auto start_indices_ty = start_indices.getType().cast<ShapedType>();
if (!start_indices_ty.hasRank()) {
return failure();
}
auto operand = gather.operand();
auto operand_ty = operand.getType().cast<ShapedType>();
if (!operand_ty.hasRank()) {
return failure();
}
int64_t index_vector_dim =
std::max<int64_t>(0, start_indices_ty.getRank() - 1);
// We can use torch_index_select if the last dimension represents the
// gather indices.
auto dimension_numbers = gather.dimension_numbers();
if (dimension_numbers.index_vector_dim().getValue().getSExtValue() !=
index_vector_dim) {
return failure();
}
// Index select only works across a single dimension.
if (!start_indices_ty.getShape().empty() &&
start_indices_ty.getShape().back() != 1) {
return failure();
}
// Only support the default case for start_index_map.
if (dimension_numbers.start_index_map().getType().getRank() != 1 ||
dimension_numbers.start_index_map()
.getValue(0)
.cast<IntegerAttr>()
.getValue() != 0) {
return failure();
}
auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
if (!result_ty) {
return failure();
}
// Offset dimensions should be the defaults.
if (dimension_numbers.offset_dims().getType().getNumElements() !=
result_ty.getRank() - index_vector_dim) {
return failure();
}
for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
if ((it.index() + index_vector_dim) != it.value()) {
return failure();
}
}
for (auto it : llvm::enumerate(gather.slice_sizes().getIntValues())) {
// First shape value must be 1.
if (it.index() == 0) {
if (it.value().getSExtValue() != 1) {
return failure();
}
continue;
}
// The gather needs to index the entire slice for each other dimension.
if (it.value().getSExtValue() != operand_ty.getDimSize(it.index())) {
return failure();
}
}
llvm::SmallVector<int64_t, 4> index_select_shape =
llvm::to_vector<4>(start_indices_ty.getShape());
for (auto dim : operand_ty.getShape().drop_front()) {
index_select_shape.push_back(dim);
}
if (!dimension_numbers.collapsed_slice_dims().getType().hasRank() ||
dimension_numbers.collapsed_slice_dims().getType().getNumElements() !=
1 ||
dimension_numbers.collapsed_slice_dims().getValue<int64_t>({0}) != 0) {
return failure();
}
auto torch_index_select = rewriter.create<TorchIndexSelectOp>(
gather.getLoc(),
RankedTensorType::get(index_select_shape, operand_ty.getElementType()),
operand, gather.start_indices(), rewriter.getI64IntegerAttr(0),
rewriter.getI64IntegerAttr(0));
rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(),
torch_index_select);
return success();
}
};
struct LegalizeGatherToTorchIndexSelect
: public PassWrapper<LegalizeGatherToTorchIndexSelect, FunctionPass> {
/// Lowers mhlo.gather ops to mhlo.torch_index_select where the pattern applies.
void runOnFunction() override {
OwningRewritePatternList patterns;
PopulateGatherToTorchIndexSelectPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
};
} // namespace
void PopulateGatherToTorchIndexSelectPatterns(
mlir::MLIRContext *context, OwningRewritePatternList *patterns) {
patterns->insert<GatherIsTorchIndexSelect>(context);
}
static PassRegistration<LegalizeGatherToTorchIndexSelect> legalize_hlo_pass(
"mhlo-legalize-gather-to-torch-index-select",
"Legalizes gathers to a torch index select.");
} // namespace mhlo
} // namespace mlir
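For reference, a condensed before/after sketch of this pattern, adapted from the gather-to-torch-index-select test added later in this change (SSA names are illustrative). The gather
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>
is rewritten under -mhlo-legalize-gather-to-torch-index-select into
%0 = "mhlo.torch_index_select"(%arg0, %arg1) {batch_dims = 0 : i64, dim = 0 : i64} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
%1 = "mhlo.reshape"(%0) : (tensor<1x3x1x4xf32>) -> tensor<1x3x4xf32>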

View File

@ -298,8 +298,8 @@ class DataMovementOpConverter : public OpConversionPattern<OpTy> {
auto nloops = resultType.getRank();
auto loc = op.getLoc();
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*inputCount=*/1,
/*outputCount=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
loc, isLHLO ? ArrayRef<Type>{} : resultType, args, /*argsIn=*/1,
/*argsOut=*/1, indexing_maps, GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());
});
@ -420,7 +420,7 @@ class LhloBroadcastInDimConverter
rewriter.create<LoadOp>(loc, operand, llvm::makeArrayRef({zero}));
rewriter.create<linalg::GenericOp>(
loc, llvm::None, llvm::makeArrayRef(operand_adaptor.output()),
/*inputCount=*/0, /*outputCount=*/1,
/*argsIn=*/0, /*argsOut=*/1,
llvm::makeArrayRef(rewriter.getMultiDimIdentityMap(nloops)),
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
@ -433,7 +433,7 @@ class LhloBroadcastInDimConverter
rewriter.create<linalg::GenericOp>(
loc, llvm::None,
llvm::makeArrayRef({operand, operand_adaptor.output()}),
/*inputCount=*/1, /*outputCount=*/1, indexing_maps,
/*argsIn=*/1, /*argsOut=*/1, indexing_maps,
GetNParallelLoopsAttrs(nloops),
[&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) {
nestedBuilder.create<linalg::YieldOp>(loc, *args.begin());

View File

@ -133,8 +133,8 @@ struct ReshapeMemRefCastOpConverter
Location loc = op->getLoc();
auto reshape_op = cast<ReshapeMemRefCastOp>(op);
Type dst_type = reshape_op.getResult().getType();
auto element_type = dst_type.cast<ShapedType>().getElementType();
auto dst_type = reshape_op.getResult().getType().cast<BaseMemRefType>();
auto element_type = dst_type.getElementType();
auto shape = reshape_op.shape();
@ -162,18 +162,17 @@ struct ReshapeMemRefCastOpConverter
desc.setAlignedPtr(rewriter, loc, ptrs_n_offset.aligned_ptr);
desc.setOffset(rewriter, loc, ptrs_n_offset.offset);
auto llvmIndexTy = typeConverter.convertType(rewriter.getIndexType())
.cast<LLVM::LLVMType>();
auto llvmIndexTyPtr = llvmIndexTy.getPointerTo();
auto llvm_index_type = typeConverter.getIndexType();
auto llvm_index_ptr_type = llvm_index_type.getPointerTo();
Value stride_carried = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
for (int i = shape_length - 1; i >= 0; --i) {
Value pos = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), i));
Value ptr = rewriter.create<LLVM::GEPOp>(
loc, llvmIndexTyPtr, shape_desc.alignedPtr(rewriter, loc),
loc, llvm_index_ptr_type, shape_desc.alignedPtr(rewriter, loc),
ValueRange{pos});
Value extracted_size = rewriter.create<LLVM::LoadOp>(loc, ptr);
desc.setSize(rewriter, loc, i, extracted_size);
@ -188,7 +187,7 @@ struct ReshapeMemRefCastOpConverter
rewriter.replaceOp(op, {desc});
} else {
Value rank = rewriter.create<LLVM::ConstantOp>(
loc, llvmIndexTy,
loc, llvm_index_type,
rewriter.getIntegerAttr(rewriter.getIndexType(), shape_length));
Value alloca =
typeConverter.promoteOneMemRefDescriptor(loc, desc, rewriter);
@ -199,15 +198,127 @@ struct ReshapeMemRefCastOpConverter
{rank, void_ptr});
rewriter.replaceOp(op, {unranked_desc});
}
} else {
/*
* TODO(pifon, herhut):
* Compute strides with llvm.loop;
* Use UnrankedMemrefDescr::ComputeSize with Alloca;
* Set all the fields using getelementptr.
*/
return failure();
return success();
}
// The shape is a rank-1 tensor with unknown length.
Value result_rank = shape_desc.size(rewriter, loc, 0);
// TODO(herhut): Properly handle address spaces.
unsigned address_space = 0;
auto target_type =
typeConverter
.convertType(UnrankedMemRefType::get(element_type, address_space))
.cast<LLVM::LLVMType>();
// Create the unranked memref descriptor that holds the ranked one. The
// inner descriptor is allocated on the stack.
UnrankedMemRefDescriptor target_desc =
UnrankedMemRefDescriptor::undef(rewriter, loc, target_type);
target_desc.setRank(rewriter, loc, result_rank);
SmallVector<Value, 1> sizes;
UnrankedMemRefDescriptor::computeSizes(rewriter, loc, typeConverter,
{target_desc}, sizes);
auto void_ptr_type =
LLVM::LLVMType::getInt8PtrTy(typeConverter.getDialect());
Value ranked_desc_mem = rewriter.create<LLVM::AllocaOp>(
loc, void_ptr_type, sizes.front(), llvm::None);
target_desc.setMemRefDescPtr(rewriter, loc, ranked_desc_mem);
// Fill the fixed parts. For this, we cast to a 0-D memref.
auto zero_d_memref_type = MemRefType::get({}, element_type);
Value as_zero_d = rewriter.create<LLVM::BitcastOp>(
loc,
typeConverter.convertType(zero_d_memref_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space),
ranked_desc_mem);
// Some common constants. Use 32-bit values where GEP struct indices require them.
auto int32_type = typeConverter.convertType(rewriter.getI32Type());
Value zero_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(0));
Value zero = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(0));
Value one = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(1));
Value two = rewriter.create<LLVM::ConstantOp>(
loc, int32_type, rewriter.getI32IntegerAttr(2));
// Set base_pointer and aligned pointer.
auto element_ptr_ptr_type = typeConverter.convertType(element_type)
.cast<LLVM::LLVMType>()
.getPointerTo(address_space)
.getPointerTo(address_space);
auto base_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, zero}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.allocated_ptr, base_gep);
auto aligned_gep = rewriter.create<LLVM::GEPOp>(
loc, element_ptr_ptr_type, as_zero_d, ValueRange({zero_index, one}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.aligned_ptr, aligned_gep);
// Set offset.
auto index_ptr_type =
typeConverter.getIndexType().getPointerTo(address_space);
auto offset_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, as_zero_d, ValueRange({zero_index, two}));
rewriter.create<LLVM::StoreOp>(loc, ptrs_n_offset.offset, offset_gep);
// Use the offset pointer as base for further addressing. Copy over the
// new shape and compute strides. For this, we need to create a loop from
// rank - 1 to 0.
Value one_index = rewriter.create<LLVM::ConstantOp>(
loc, typeConverter.getIndexType(), rewriter.getIndexAttr(1));
auto target_shape_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, offset_gep, ValueRange({one}));
auto target_strides_base = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({result_rank}));
auto shape_ptr = shape_desc.alignedPtr(rewriter, loc);
auto result_rank_minus_one =
rewriter.create<LLVM::SubOp>(loc, result_rank, one_index);
Block *init_block = rewriter.getInsertionBlock();
Block *cond_block =
rewriter.splitBlock(init_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToEnd(init_block);
rewriter.create<LLVM::BrOp>(
loc, ValueRange({result_rank_minus_one, one_index}), cond_block);
rewriter.setInsertionPointToStart(cond_block);
auto index_arg = cond_block->addArgument(typeConverter.getIndexType());
auto stride_arg = cond_block->addArgument(typeConverter.getIndexType());
auto pred = rewriter.create<LLVM::ICmpOp>(
loc, LLVM::LLVMType::getInt1Ty(typeConverter.getDialect()),
LLVM::ICmpPredicate::sge, index_arg, zero_index);
Block *body_block =
rewriter.splitBlock(cond_block, rewriter.getInsertionPoint());
rewriter.setInsertionPointToStart(body_block);
// Copy size from shape to descriptor.
auto size_load_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, shape_ptr, ValueRange{index_arg});
auto extracted_size = rewriter.create<LLVM::LoadOp>(loc, size_load_gep);
auto size_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_shape_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, extracted_size, size_store_gep);
// Write stride value and compute next one.
auto stride_store_gep = rewriter.create<LLVM::GEPOp>(
loc, index_ptr_type, target_strides_base, ValueRange({index_arg}));
rewriter.create<LLVM::StoreOp>(loc, stride_arg, stride_store_gep);
auto next_stride =
rewriter.create<LLVM::MulOp>(loc, stride_arg, extracted_size);
// Decrement loop counter and branch back.
auto decrement = rewriter.create<LLVM::SubOp>(loc, index_arg, one_index);
rewriter.create<LLVM::BrOp>(loc, ValueRange({decrement, next_stride}),
cond_block);
Block *remainder =
rewriter.splitBlock(body_block, rewriter.getInsertionPoint());
// Hook up the cond exit to the remainder.
rewriter.setInsertionPointToEnd(cond_block);
rewriter.create<LLVM::CondBrOp>(loc, pred, body_block, ValueRange(),
remainder, ValueRange());
// Reset position to beginning of new remainder block.
rewriter.setInsertionPointToStart(remainder);
rewriter.replaceOp(op, {target_desc});
return success();
}
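As a worked example of the stride loop above (the shape values are illustrative): if the runtime shape read from the shape operand is [2, 3, 4], the loop visits i = 2, 1, 0 with an incoming stride of 1 and stores strides [12, 4, 1], since stride(2) = 1, stride(1) = 1 * size(2) = 4, and stride(0) = 4 * size(1) = 12, which is the contiguous row-major layout the descriptor expects.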

View File

@ -51,7 +51,7 @@ class LowerComplex : public PassWrapper<LowerComplex, FunctionPass> {
} // end anonymous namespace
namespace mlir {
namespace hlo {
namespace mhlo {
namespace {
#include "tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/generated_lower_complex.inc"
@ -62,14 +62,14 @@ void PopulateComplexLoweringPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
populateWithGenerated(context, patterns);
}
} // end namespace hlo
} // end namespace mhlo
} // end namespace mlir
// Lowers the complex operations that can be represented using other operations.
void LowerComplex::runOnFunction() {
// Add lowering patterns to the list.
OwningRewritePatternList patterns;
mlir::hlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
mlir::mhlo::PopulateComplexLoweringPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}

View File

@ -0,0 +1,187 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file provides optional optimization patterns for mhlo, canonicalizing
// operations to equivalent but potentially more efficient operations.
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <numeric>
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassRegistry.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
using mlir::OwningRewritePatternList;
namespace mlir {
namespace mhlo {
namespace {
// Returns 1D 64-bit dense elements attribute with the given values.
static DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values,
Builder* builder) {
RankedTensorType ty = RankedTensorType::get(
{static_cast<int64_t>(values.size())}, builder->getIntegerType(64));
return DenseIntElementsAttr::get(ty, values);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
class GatherIsSlice : public OpRewritePattern<GatherOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherOp gather,
PatternRewriter& rewriter) const override {
auto dimension_numbers = gather.dimension_numbers();
// Inputs need to be ranked and statically shaped to lower.
if (!gather.operand().getType().cast<ShapedType>().hasRank() ||
!gather.operand().getType().cast<ShapedType>().hasStaticShape() ||
!gather.start_indices().getType().cast<ShapedType>().hasRank() ||
!gather.start_indices().getType().cast<ShapedType>().hasStaticShape()) {
return failure();
}
if (dimension_numbers.index_vector_dim().getValue().getSExtValue() != 0) {
return failure();
}
// TODO(suderman): Handle start index map != {0}.
if (!dimension_numbers.start_index_map() ||
dimension_numbers.start_index_map().getType().getRank() != 1 ||
dimension_numbers.start_index_map().getType().getDimSize(0) != 1 ||
dimension_numbers.start_index_map()
.getValue({0})
.cast<IntegerAttr>()
.getValue() != 0) {
return failure();
}
auto result_ty = gather.getResult().getType().dyn_cast<RankedTensorType>();
// Requires a ranked output.
if (!result_ty) {
return failure();
}
if (dimension_numbers.offset_dims().getType().getNumElements() !=
result_ty.getRank()) {
return failure();
}
for (auto it : llvm::enumerate(dimension_numbers.offset_dims())) {
if (it.index() != it.value()) {
return failure();
}
}
// The gather must provide one slice size per operand dimension.
if (gather.slice_sizes().getNumElements() !=
gather.operand().getType().cast<ShapedType>().getRank()) {
return failure();
}
// There must be enough slice sizes to cover the collapsed dimension and every result dimension.
if (gather.slice_sizes().getType().cast<ShapedType>().getNumElements() <
result_ty.getShape().size() + 1) {
return failure();
}
for (auto it : llvm::enumerate(result_ty.getShape())) {
if (gather.slice_sizes()
.getValue(it.index() + 1)
.cast<IntegerAttr>()
.getValue() != it.value()) {
return failure();
}
}
auto gather_start_indices = gather.start_indices();
auto gather_start_indices_ty =
gather_start_indices.getType().cast<ShapedType>();
llvm::SmallVector<Value, 4> slice_start_indices;
if (gather_start_indices_ty.getRank() == 0) {
slice_start_indices.push_back(gather_start_indices);
} else if (gather_start_indices_ty.getRank() == 1) {
for (int i = 0; i < gather_start_indices_ty.getDimSize(0); i++) {
auto start = GetI64ElementsAttr({i}, &rewriter);
auto limit = GetI64ElementsAttr({i + 1}, &rewriter);
auto stride = GetI64ElementsAttr({1}, &rewriter);
auto indicesSlice = rewriter.create<SliceOp>(
gather.getLoc(), gather_start_indices, start, limit, stride);
auto reshaped = rewriter.create<ReshapeOp>(
gather.getLoc(),
RankedTensorType::get(
{}, indicesSlice.getType().cast<ShapedType>().getElementType()),
indicesSlice);
slice_start_indices.push_back(reshaped);
}
} else {
return failure();
}
auto sliceSizes = gather.slice_sizes();
auto sliceSizesTy = sliceSizes.getType();
if (sliceSizesTy.getRank() != 1) {
return failure();
}
// Start indices have implicit zeros when not specified. This is because
// gather behaves like slicing, where full slices are inferred. Add any
// missing zeros as necessary.
auto zero = rewriter.create<ConstOp>(
gather.getLoc(), rewriter.getZeroAttr(RankedTensorType::get(
{}, gather_start_indices_ty.getElementType())));
while (slice_start_indices.size() < sliceSizesTy.getDimSize(0)) {
slice_start_indices.push_back(zero);
}
SmallVector<int64_t, 5> sliceShape;
for (auto shapeValue : gather.slice_sizes().getIntValues()) {
sliceShape.push_back(shapeValue.getSExtValue());
}
auto sliceTy =
RankedTensorType::get(sliceShape, result_ty.getElementType());
auto slice = rewriter.create<DynamicSliceOp>(
gather.getLoc(), sliceTy, gather.operand(), slice_start_indices,
gather.slice_sizes());
rewriter.replaceOpWithNewOp<ReshapeOp>(gather, gather.getType(), slice);
return success();
}
};
} // end anonymous namespace
void PopulateOptimizeMHLOPatterns(MLIRContext* context,
OwningRewritePatternList* patterns) {
patterns->insert<GatherIsSlice>(context);
}
} // end namespace mhlo
} // end namespace mlir
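A condensed sketch of what GatherIsSlice produces, adapted from the optimize test added later in this change (SSA names are illustrative). A gather with a single rank-1 start index and slice_sizes = [1, 1, 2]:
%res = "mhlo.gather"(%operand, %start) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
becomes a dynamic slice whose missing start indices are padded with zero scalars, followed by a reshape to the original result type:
%zero = mhlo.constant dense<0> : tensor<i64>
%idx = "mhlo.reshape"(%start) : (tensor<1xi64>) -> tensor<i64>
%slice = "mhlo.dynamic-slice"(%operand, %idx, %zero, %zero) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>} : (tensor<2x1x2xi32>, tensor<i64>, tensor<i64>, tensor<i64>) -> tensor<1x1x2xi32>
%res = "mhlo.reshape"(%slice) : (tensor<1x1x2xi32>) -> tensor<1x2xi32>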

View File

@ -0,0 +1,49 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
using mlir::FunctionPass;
using mlir::PassRegistration;
using mlir::PassWrapper;
namespace {
class OptimizeMhlo : public PassWrapper<OptimizeMhlo, FunctionPass> {
public:
explicit OptimizeMhlo() : PassWrapper<OptimizeMhlo, FunctionPass>() {}
/// Runs the optional optimization patterns over MHLO.
void runOnFunction() override;
};
} // end anonymous namespace
// Applies the optional MHLO optimization patterns to the current function.
void OptimizeMhlo::runOnFunction() {
// Add lowering patterns to the list.
mlir::OwningRewritePatternList patterns;
mlir::mhlo::PopulateOptimizeMHLOPatterns(&getContext(), &patterns);
applyPatternsAndFoldGreedily(getFunction(), patterns);
}
static PassRegistration<OptimizeMhlo> pass("mhlo-test-optimize",
"Run optional HLO optimizations.");

View File

@ -365,6 +365,16 @@ func @dynamic_broadcast_in_dim_op_not_actually_dynamic(%arg0: tensor<4xf32>, %ar
return %0 : tensor<5x4xf32>
}
// CHECK-LABEL: func @dynamic_broadcast_in_dim_to_same_shape
func @dynamic_broadcast_in_dim_to_same_shape(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// CHECK-SAME: %[[ARG:.*]]: tensor<?xf32>
%0 = shape.shape_of %arg0 : tensor<?xf32>
%1 = shape.to_extent_tensor %0 : tensor<1xindex>
%2 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %1) { broadcast_dimensions = dense<0> : tensor<1xi64> } : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: return %[[ARG]] : tensor<?xf32>
return %2 : tensor<?xf32>
}
// CHECK-LABEL: func @broadcast_in_dim_constant_fold_0d
func @broadcast_in_dim_constant_fold_0d() -> tensor<1x64x224x224xf32> {
%cst = mhlo.constant dense<0.000000e+00> : tensor<f32>

View File

@ -8,7 +8,7 @@
func @broadcast_add(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<1xindex> {
// CHECK-DAG: %[[ARG0_S:.+]] = shape.shape_of %[[ARG0]]
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-DAG: %[[BCAST_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]]
// CHECK: return %[[EXTENTS]]
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>

View File

@ -18,7 +18,7 @@ func @dynamicBroadcast(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-DAG: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-DAG: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>}
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>}
@ -39,7 +39,7 @@ func @dynamicBroadcastComplex(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK-NEXT: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK-NEXT: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK-NEXT: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK-NEXT: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -60,7 +60,7 @@ func @dynamicBroadcastCompare(%arg0: tensor<?xf32>, %arg1: tensor<?x?xf32>) -> t
// CHECK-DAG: %[[ARG1_S:.+]] = shape.shape_of %[[ARG1]]
// CHECK: %[[WITNESS:.+]] = shape.cstr_broadcastable %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[FINAL_RESULT:.+]] = shape.assuming %[[WITNESS]]
// CHECK: %[[RESULT_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]])
// CHECK: %[[RESULT_S:.+]] = shape.broadcast %[[ARG0_S]], %[[ARG1_S]]
// CHECK: %[[RESULT_EXTENTS:.+]] = shape.to_extent_tensor %[[RESULT_S]]
// CHECK-DAG: %[[ARG0_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG0]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
// CHECK-DAG: %[[ARG1_B:.+]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG1]], %[[RESULT_EXTENTS]]) {broadcast_dimensions = dense<[0, 1]> : tensor<2xi64>} : (tensor<?x?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
@ -237,3 +237,77 @@ func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x
%0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1>
return %0 : tensor<4xi1>
}
// -----
func @addScalarUnranked(%arg0: tensor<f32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<f32>, tensor<*xf32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addScalarUnranked(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<f32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*xf32>
// CHECK-SAME: ) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_1]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_1]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<f32>
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_0]], %[[SHAPE_RESHAPED]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SCALAR_SHAPE:.*]] = shape.const_shape []
// CHECK: %[[BROADCASTED_SHAPE:.*]] = shape.broadcast %[[SCALAR_SHAPE]], %[[SHAPE_RESHAPED]]
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[BROADCASTED_SHAPE]] : tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_0]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_1]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }
// -----
func @addUnrankedScalar(%arg0: tensor<*xf32>, %arg1: tensor<f32>) -> tensor<*xf32> {
%0 = chlo.broadcast_add %arg0, %arg1 : (tensor<*xf32>, tensor<f32>)
-> tensor<*xf32>
return %0 : tensor<*xf32>
}
// CHECK-LABEL: func @addUnrankedScalar(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*xf32>,
// CHECK-SAME: %[[ARG_1:.*]]: tensor<f32>) -> tensor<*xf32> {
// First handle the dynamic reshaping of the unranked operand
// to a 1D tensor.
// CHECK: %[[SHAPE_0:.*]] = shape.shape_of %[[ARG_0]] : tensor<*xf32>
// CHECK: %[[NUM_ELEMENTS:.*]] = shape.num_elements %[[SHAPE_0]]
// CHECK: %[[SIZE:.*]] = shape.size_to_index %[[NUM_ELEMENTS]]
// CHECK: %[[SIZE_TENSOR:.*]] = tensor_from_elements(%[[SIZE]]) : tensor<1xindex>
// CHECK: %[[RESHAPED:.*]] = "mhlo.dynamic_reshape"(%[[ARG_0]], %[[SIZE_TENSOR]]) : (tensor<*xf32>, tensor<1xindex>) -> tensor<?xf32>
// The assuming region is part of the second stage of lowering
// with ranked broadcasting logic.
// CHECK: %[[SHAPE_RESHAPED:.*]] = shape.shape_of %[[RESHAPED]] : tensor<?xf32>
// CHECK: %[[SHAPE_1:.*]] = shape.shape_of %[[ARG_1]] : tensor<f32>
// CHECK: %[[WITNESS:.*]] = shape.cstr_broadcastable %[[SHAPE_RESHAPED]], %[[SHAPE_1]]
// CHECK: %[[ASSUMING_RESULT:.*]] = shape.assuming %[[WITNESS]] -> (tensor<?xf32>) {
// CHECK: %[[SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_RESHAPED]] : tensor<1xindex>
// CHECK: %[[BROADCASTED_LHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[RESHAPED]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RHS:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[ARG_1]], %[[SHAPE_TENSOR]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>, tensor<1xindex>) -> tensor<?xf32>
// CHECK: %[[BROADCASTED_RESULT:.*]] = mhlo.add %[[BROADCASTED_LHS]], %[[BROADCASTED_RHS]] : tensor<?xf32>
// CHECK: shape.assuming_yield %[[BROADCASTED_RESULT]] : tensor<?xf32>
// CHECK: }
// As part of the unranked logic, the result is reshaped back
// to an unranked tensor.
// CHECK: %[[PROPER_SHAPE_TENSOR:.*]] = shape.to_extent_tensor %[[SHAPE_0]] : tensor<?xindex>
// CHECK: %[[RESHAPED_RESULT:.*]] = "mhlo.dynamic_reshape"(%[[VAL_19:.*]], %[[PROPER_SHAPE_TENSOR]]) : (tensor<?xf32>, tensor<?xindex>) -> tensor<*xf32>
// CHECK: return %[[RESHAPED_RESULT]] : tensor<*xf32>
// CHECK: }

View File

@ -0,0 +1,41 @@
// RUN: mlir-hlo-opt -mhlo-legalize-gather-to-torch-index-select %s -o - | FileCheck %s
// CHECK-LABEL: @gather_to_index_select
func @gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x4xf32> {
// CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
// CHECK-SAME: batch_dims = 0 : i64,
// CHECK-SAME: dim = 0 : i64
// CHECK-SAME: } : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x1x4xf32>
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x4xf32>
// CHECK: return [[RES]]
return %0 : tensor<1x3x4xf32>
}
// CHECK-LABEL: @scalar_gather_to_index_select
func @scalar_gather_to_index_select(%arg0 : tensor<5x4xf32>, %arg1 : tensor<i32>) -> tensor<1x4xf32> {
// CHECK: [[TIS:%.+]] = "mhlo.torch_index_select"(%arg0, %arg1) {
// CHECK-SAME: batch_dims = 0 : i64,
// CHECK-SAME: dim = 0 : i64
// CHECK-SAME: } : (tensor<5x4xf32>, tensor<i32>) -> tensor<4xf32>
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[TIS]])
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 0 : i64, offset_dims = dense<[0, 1]> : tensor<2xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<i32>) -> tensor<1x4xf32>
// CHECK: return [[RES]]
return %0 : tensor<1x4xf32>
}
// CHECK-LABEL: @gather_no_lowering_subslice
func @gather_no_lowering_subslice(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x1xi32>) -> tensor<1x3x3xf32> {
// CHECK: "mhlo.gather"
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 3]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x1xi32>) -> tensor<1x3x3xf32>
return %0 : tensor<1x3x3xf32>
}
// CHECK-LABEL: @gather_no_lowering_multidim
func @gather_no_lowering_multidim(%arg0 : tensor<5x4xf32>, %arg1 : tensor<1x3x2xi32>) -> tensor<1x3x4xf32> {
// CHECK: "mhlo.gather"
%0 = "mhlo.gather"(%arg0, %arg1) {dimension_numbers = {collapsed_slice_dims = dense<0> : tensor<1xi64>, index_vector_dim = 2 : i64, offset_dims = dense<2> : tensor<1xi64>, start_index_map = dense<0> : tensor<1xi64>}, indices_are_sorted = false, slice_sizes = dense<[1, 4]> : tensor<2xi64>} : (tensor<5x4xf32>, tensor<1x3x2xi32>) -> tensor<1x3x4xf32>
return %0 : tensor<1x3x4xf32>
}

View File

@ -487,3 +487,26 @@ func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor
} : (tensor<2x2x3x4xf32>, tensor<3x5x5x3xf32>) -> tensor<3x5x5x4xf32>
return %out : tensor<3x5x5x4xf32>
}
// -----
// BOTH-LABEL: func @reduce
func @reduce(%arg0: tensor<1x8xf32>, %arg1: tensor<f32>) -> tensor<1xf32> {
// BOTH: %[[OUT:.*]] = alloc() : memref<1xf32>
// BOTH: "lmhlo.reduce"(%{{.+}}, %{{.+}}, %[[OUT]]) ( {
// BOTH: ^bb0(%[[ARG1:.*]]: memref<f32>, %[[ARG2:.*]]: memref<f32>,
// BOTH-SAME: %[[ARG3:.*]]: memref<f32>):
// BOTH: %[[TMP:.*]] = alloc() : memref<f32>
// BOTH: "lmhlo.add"(%[[ARG1]], %[[ARG2]], %[[TMP]])
// BOTH: "lmhlo.copy"(%[[TMP]], %[[ARG3]])
// BOTH: "lmhlo.terminator"() : () -> ()
// BOTH: }) {dimensions = dense<1> : tensor<1xi64>}
// BOTH-SAME: : (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
%0 = "mhlo.reduce"(%arg0, %arg1) ( {
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>): // no predecessors
%1 = mhlo.add %arg2, %arg3 : tensor<f32>
"mhlo.return"(%1) : (tensor<f32>) -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (tensor<1x8xf32>, tensor<f32>) -> tensor<1xf32>
return %0 : tensor<1xf32>
}

View File

@ -91,3 +91,25 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>,
dealloc %0 : memref<2x2xf32>
"lmhlo.terminator"() : () -> ()
}
// -----
// CHECK-LABEL: func @reduce
func @reduce(%arg0: memref<1x8xf32>, %arg1: memref<f32>, %arg2: memref<1xf32>) {
%0 = alloc() : memref<1xf32>
"lmhlo.reduce"(%arg0, %arg1, %0) ( {
// CHECK: ^bb0(%[[ARG0:.*]]: memref<f32>, %[[ARG1:.*]]: memref<f32>,
// CHECK-SAME: %[[ARG2:.*]]: memref<f32>)
^bb0(%arg3: memref<f32>, %arg4: memref<f32>, %arg5: memref<f32>):
%1 = alloc() : memref<f32>
// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]])
"lmhlo.add"(%arg3, %arg4, %1)
: (memref<f32>, memref<f32>, memref<f32>) -> ()
// CHECK-NOT: lmhlo.copy
"lmhlo.copy"(%1, %arg5) : (memref<f32>, memref<f32>) -> ()
"lmhlo.terminator"() : () -> ()
}) {dimensions = dense<1> : tensor<1xi64>}
: (memref<1x8xf32>, memref<f32>, memref<1xf32>) -> ()
"lmhlo.copy"(%0, %arg2) : (memref<1xf32>, memref<1xf32>) -> ()
return
}

View File

@ -0,0 +1,64 @@
// RUN: mlir-hlo-opt %s -pass-pipeline='func(mhlo-test-optimize)' | FileCheck %s
// CHECK-LABEL: @gather_is_slice_no_rank
func @gather_is_slice_no_rank(%arg0: tensor<2x1x2xi32>, %arg1: tensor<i64>) -> tensor<1x2xi32> {
// CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, %arg1, [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"([[SLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<i64>) -> tensor<1x2xi32>
// CHECK: return [[RESHAPE]]
return %res : tensor<1x2xi32>
}
// CHECK-LABEL: @gather_is_slice
func @gather_is_slice(%arg0: tensor<2x1x2xi32>, %arg1: tensor<1xi64>) -> tensor<1x2xi32> {
// CHECK: [[CST:%.+]] = mhlo.constant dense<0> : tensor<i64>
// CHECK: [[RESHAPE:%.+]] = "mhlo.reshape"(%arg1)
// CHECK: [[SLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE]], [[CST]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK: [[RES:%.+]] = "mhlo.reshape"([[SLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<1xi64>) -> tensor<1x2xi32>
// CHECK: return [[RES]]
return %res : tensor<1x2xi32>
}
// CHECK-LABEL: @gather_is_slice_multiple_start_indices
func @gather_is_slice_multiple_start_indices(%arg0: tensor<2x1x2xi32>, %arg1: tensor<2xi64>) -> tensor<1x2xi32> {
// CHECK-DAG: [[CST:%.+]] = mhlo.constant dense<0>
// CHECK-DAG: [[SLICE1:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[RESHAPE1:%.+]] = "mhlo.reshape"([[SLICE1]])
// CHECK-DAG: [[SLICE2:%.+]] = "mhlo.slice"(%arg1) {limit_indices = dense<2> : tensor<1xi64>, start_indices = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>}
// CHECK-DAG: [[RESHAPE2:%.+]] = "mhlo.reshape"([[SLICE2]])
// CHECK-DAG: [[DSLICE:%.+]] = "mhlo.dynamic-slice"(%arg0, [[RESHAPE1]], [[RESHAPE2]], [[CST]]) {slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>}
// CHECK-DAG: [[RES:%.+]] = "mhlo.reshape"([[DSLICE]])
%res = "mhlo.gather"(%arg0, %arg1) {
dimension_numbers = {
collapsed_slice_dims = dense<0> : tensor<1xi64>,
index_vector_dim = 0 : i64,
offset_dims = dense<[0, 1]> : tensor<2xi64>,
start_index_map = dense<0> : tensor<1xi64>
},
slice_sizes = dense<[1, 1, 2]> : tensor<3xi64>
} : (tensor<2x1x2xi32>, tensor<2xi64>) -> tensor<1x2xi32>
// CHECK: return [[RES]]
return %res : tensor<1x2xi32>
}

View File

@ -25,7 +25,6 @@ package_group(
filegroup(
name = "tensorflow_lite_ops_td_files",
srcs = [
"experimental/tfl_hardware_interfaces.td",
"ir/tfl_op_interfaces.td",
"ir/tfl_ops.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
@ -221,18 +220,14 @@ cc_library(
],
deps = [
":tensorflow_lite_ops_inc_gen",
":validators",
"//tensorflow/compiler/mlir/lite/experimental/estimators:cost_estimators",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/lite/schema:schema_fbs",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:LoopLikeInterface",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:SideEffects",
"@llvm-project//mlir:StandardOps",
@ -242,6 +237,28 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "constant_utils",
srcs = [
"utils/constant_utils.cc",
],
hdrs = [
"utils/constant_utils.h",
],
copts = ["-std=c++14"],
deps = [
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:status",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:Support",
],
)
cc_library(
name = "lstm_utils",
srcs = [
@ -273,6 +290,7 @@ cc_library(
deps = [
":tensorflow_lite",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/core:framework",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@ -338,6 +356,7 @@ cc_library(
"transforms/optimize_functional_ops.cc",
"transforms/prepare_composite_functions_tf.cc",
"transforms/prepare_tf.cc",
"transforms/raise_custom_ops.cc",
"transforms/runtime_verify.cc",
"transforms/split_merged_operands.cc",
"transforms/trim_functions_tf.cc",
@ -349,7 +368,7 @@ cc_library(
"transforms/passes.h",
],
deps = [
":common",
":constant_utils",
":lstm_utils",
":stateful_ops_utils",
":tensorflow_lite",
@ -360,6 +379,7 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/mlir/tensorflow:unroll_batch_matmul_pass",
"//tensorflow/compiler/xla:status",
@ -368,7 +388,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:tensor_list",
"//tensorflow/core/platform:logging",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
@ -399,7 +418,6 @@ cc_library(
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:StandardOps",
@ -433,7 +451,6 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
@ -456,7 +473,6 @@ cc_library(
"//tensorflow/lite/tools/optimize/sparsity:format_converter",
"@com_google_absl//absl/base",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
@ -480,7 +496,6 @@ gentbl(
td_srcs = [
"@llvm-project//mlir:include/mlir/Interfaces/LoopLikeInterface.td",
"//tensorflow/compiler/mlir/lite/quantization:quantization_td_files",
"experimental/tfl_hardware_interfaces.td",
"ir/tfl_op_interfaces.td",
],
)
@ -609,8 +624,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_tensor",
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_dialect_registration",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
@ -620,7 +633,7 @@ cc_library(
"//tensorflow/core/platform:status",
"//tensorflow/lite:schema_fbs_version",
"//tensorflow/lite:string_util",
"//tensorflow/lite/delegates/flex:whitelisted_flex_ops_lib",
"//tensorflow/lite/delegates/flex:allowlisted_flex_ops_lib",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/schema:schema_fbs",
"//tensorflow/lite/tools/versioning",
@ -651,7 +664,6 @@ cc_library(
":flatbuffer_tflite_operator_lib",
":tensorflow_lite",
":tensorflow_lite_dialect_registration",
"//tensorflow/compiler/mlir/lite/quantization:quantization_lib",
"//tensorflow/compiler/mlir/tensorflow:mangling_util",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:statusor",
@ -724,7 +736,6 @@ cc_library(
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:MlirTranslateMain",
@ -858,10 +869,8 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AllPassesAndDialects",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",
"@llvm-project//mlir:Transforms",
],
)

View File

@ -8,9 +8,6 @@ package(
cc_library(
name = "cost_estimators",
textual_hdrs = [
"estimator.h",
"cpu_estimators.h",
"gpu_estimators.h",
"hardware.h",
"arithmetic_count_util.h",
],

View File

@ -15,13 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ARITHMETIC_COUNT_UTIL_H_
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
// For add/mul/div/sub and other broadcastable ops.
class ArithmeticCountUtilHelper {
public:
static bool GetArithmeticCountForBroadcastableOp(mlir::Operation* op,
int64_t* count) {
auto output = op->getResult(0);
auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
auto output_type =
output.getType().dyn_cast_or_null<mlir::RankedTensorType>();
if (!output_type || !output_type.hasStaticShape()) return false;
*count = output_type.getNumElements();
@ -31,7 +35,8 @@ class ArithmeticCountUtilHelper {
static bool GetInputTensorTotalSize(mlir::Operation* op, int64_t* count) {
int64_t total_count = 0;
for (auto input : op->getOperands()) {
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
auto input_type =
input.getType().dyn_cast_or_null<mlir::RankedTensorType>();
if (!input_type || !input_type.hasStaticShape()) {
return false;
}
@ -43,14 +48,16 @@ class ArithmeticCountUtilHelper {
// For conv2d/depthwise_conv/fully_connected ops.
// This algorithm actually comes from TOCO tooling_util.cc
static bool GetArithmeticCountForConvAndFullyconnectedOp(Operation* op,
static bool GetArithmeticCountForConvAndFullyconnectedOp(mlir::Operation* op,
int64_t* count) {
auto weight = op->getOperand(1);
auto weight_type = weight.getType().dyn_cast_or_null<RankedTensorType>();
auto weight_type =
weight.getType().dyn_cast_or_null<mlir::RankedTensorType>();
if (weight_type == nullptr || !weight_type.hasStaticShape()) return false;
auto output = op->getResult(0);
auto output_type = output.getType().dyn_cast_or_null<RankedTensorType>();
auto output_type =
output.getType().dyn_cast_or_null<mlir::RankedTensorType>();
if (output_type == nullptr || !output_type.hasStaticShape()) return false;
int64_t cols = 1;
@ -63,7 +70,8 @@ class ArithmeticCountUtilHelper {
auto bias = op->getOperand(2);
if (bias) {
auto bias_type = bias.getType().dyn_cast_or_null<RankedTensorType>();
auto bias_type =
bias.getType().dyn_cast_or_null<mlir::RankedTensorType>();
if (bias_type && bias_type.hasStaticShape()) {
*count += bias_type.getNumElements();
}

View File

@ -1,149 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_
// CPU
constexpr float kCPUArithmeticUnitCost = 1.0;
// This basically assumes pure load/store. This is just fake data.
constexpr float kCPUCopyUnitCost = 0.5;
constexpr float kCPUDefaultCost = 3.0f;
// Default values.
constexpr float kCPUDefaultFixedValuedCost = 10000.0;
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kCPUArithmeticUnitCost;
}
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kCPUArithmeticUnitCost;
}
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kCPUArithmeticUnitCost;
}
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kCPUArithmeticUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pack
template <>
class TFLiteCostEstimator<PackOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::CPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kCPUCopyUnitCost * count;
return kCPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_CPU_ESTIMATORS_H_

View File

@ -1,51 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/experimental/estimators/hardware.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
template <typename Op, typename TargetHardware>
class TFLiteCostEstimator {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) {
llvm::errs() << "No defined support for op: "
<< op->getName().getStringRef().str();
return false;
}
};
// All ops on CPU are supported.
// TODO(karimnosseir): Only allow TFL ops in the "TFL_OP" param.
template <typename TFL_OP>
class TFLiteCostEstimator<TFL_OP, hardware::CPU> {
public:
// TODO(karimnosseir): Update to a table-based method and look up the
// cost from a loadable table?
static double GetCost(mlir::Operation* op) { return 0.0; }
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_ESTIMATOR_H_

View File

@ -1,543 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
// GPU
constexpr float kGPUArithmeticUnitCost = 0.2;
// The copy can be a non-consecutive copy. This is just fake data.
constexpr float kGPUCopyUnitCost = 0.2;
constexpr float kGPUDefaultCost = 1.0f;
// Default values.
constexpr float kGPUDefaultFixedValuedCost = 10000.0;
// tfl.abs
template <>
class TFLiteCostEstimator<AbsOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.add
template <>
class TFLiteCostEstimator<AddOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.average_pool_2d
template <>
class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kGPUArithmeticUnitCost;
}
return kGPUDefaultFixedValuedCost;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.cos
template <>
class TFLiteCostEstimator<CosOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.depthwise_conv_2d
template <>
class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kGPUArithmeticUnitCost;
}
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.div
template <>
class TFLiteCostEstimator<DivOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.exp
template <>
class TFLiteCostEstimator<ExpOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t arithmetic_count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForConvAndFullyconnectedOp(
op, &arithmetic_count)) {
return arithmetic_count * kGPUArithmeticUnitCost;
}
return kGPUDefaultFixedValuedCost;
}
// TODO(renjieliu): we need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.hard_swish
template <>
class TFLiteCostEstimator<HardSwishOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.log
template <>
class TFLiteCostEstimator<LogOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.max_pool_2d
template <>
class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mirror_pad
template <>
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.maximum
template <>
class TFLiteCostEstimator<MaximumOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.custom
template <>
class TFLiteCostEstimator<CustomOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mean
template <>
class TFLiteCostEstimator<MeanOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
  // TODO(renjieliu): check for constraints.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.minimum
template <>
class TFLiteCostEstimator<MinimumOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetArithmeticCountForBroadcastableOp(op,
&count))
return kGPUArithmeticUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pad
template <>
class TFLiteCostEstimator<PadOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.pow
template <>
class TFLiteCostEstimator<PowOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.prelu
template <>
class TFLiteCostEstimator<PReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu
template <>
class TFLiteCostEstimator<ReluOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.relu6
template <>
class TFLiteCostEstimator<Relu6Op, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.reshape
template <>
class TFLiteCostEstimator<ReshapeOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
int64_t count;
if (ArithmeticCountUtilHelper::GetInputTensorTotalSize(op, &count))
return kGPUCopyUnitCost * count;
return kGPUDefaultFixedValuedCost;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.rsqrt
template <>
class TFLiteCostEstimator<RsqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sin
template <>
class TFLiteCostEstimator<SinOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.slice
template <>
class TFLiteCostEstimator<SliceOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.softmax
template <>
class TFLiteCostEstimator<SoftmaxOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.space_to_depth
template <>
class TFLiteCostEstimator<SpaceToDepthOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.sqrt
template <>
class TFLiteCostEstimator<SqrtOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.square
template <>
class TFLiteCostEstimator<SquareOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.squared_difference
template <>
class TFLiteCostEstimator<SquaredDifferenceOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.strided_slice
template <>
class TFLiteCostEstimator<StridedSliceOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.tanh
template <>
class TFLiteCostEstimator<TanhOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose
template <>
class TFLiteCostEstimator<TransposeOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.transpose_conv
template <>
class TFLiteCostEstimator<TransposeConvOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
#endif // TENSORFLOW_COMPILER_MLIR_LITE_EXPERIMENTAL_ESTIMATORS_GPU_ESTIMATORS_H_
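For illustration, a minimal sketch of how these specializations are meant to be consulted. The helper function below is hypothetical; only the estimator names and the default-cost constant come from the header above.

```cpp
// Hypothetical helper, not part of this header: queries the GPU cost
// estimator for a few op kinds and falls back to the fixed default cost.
double EstimateGpuCost(mlir::Operation* op) {
  if (llvm::isa<AddOp>(op))
    return TFLiteCostEstimator<AddOp, hardware::GPU>::GetCost(op);
  if (llvm::isa<Conv2DOp>(op))
    return TFLiteCostEstimator<Conv2DOp, hardware::GPU>::GetCost(op);
  if (llvm::isa<ReshapeOp>(op))
    return TFLiteCostEstimator<ReshapeOp, hardware::GPU>::GetCost(op);
  // Ops without an arithmetic- or copy-based estimate use the flat default.
  return kGPUDefaultFixedValuedCost;
}
```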

View File

@ -1,76 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// WARNING: This Interface is experimental, DO NOT USE.
// This is the Target Hardware operation interface definition file
// for TensorFlow Lite.
#ifndef TFL_TARGET_HARDWARE_OP_INTERFACES
#define TFL_TARGET_HARDWARE_OP_INTERFACES
def TFL_CpuTargetOp : OpInterface<"CpuOpTargetInterface"> {
let description = [{
Interface for ops to run on CPU.
}];
let methods = [
InterfaceMethod<
[{Returns the cost of running this op on CPU.}],
// TODO(karimnosseir): Change to return Cost object instead.
"double", "GetOpCost", (ins "mlir::Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::GetCost(op_to_check);
}]
>,
InterfaceMethod<
[{Returns whether this op can be run on CPU.}],
"bool", "IsSupported", (ins "mlir::Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::CPU>::IsSupported(op_to_check);
}]
>,
];
}
def TFL_GpuTargetOp : OpInterface<"GpuOpTargetInterface"> {
let description = [{
Interface for ops to run on GPU.
}];
let methods = [
InterfaceMethod<
[{Returns the cost of running this op on GPU.}],
// TODO(karimnosseir): Change to return Cost object instead.
"double", "GetOpCost", (ins "Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::GetCost(op_to_check);
}]
>,
InterfaceMethod<
[{Returns whether this op can be run on GPU.}],
"bool", "IsSupported", (ins "Operation*":$op_to_check), [{
// TODO(karimnosseir): Consider changing to another way that doesn't
// rely on template param name.
return TFL::TFLiteCostEstimator<ConcreteOp, TFL::hardware::GPU>::IsSupported(op_to_check);
}]
>,
];
}
#endif // TFL_TARGET_HARDWARE_OP_INTERFACES
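Before their removal in this change, the generated interface classes were intended to be queried roughly as follows. This is a sketch only: the helper name and the fallback value are assumptions, while the interface and method names come from the definitions above.

```cpp
// Sketch: dispatch through the generated target-hardware op interfaces.
// GpuOpTargetInterface / CpuOpTargetInterface are generated from the
// OpInterface definitions above; the helper itself is hypothetical.
double GetTargetCost(mlir::Operation* op, bool prefer_gpu) {
  if (prefer_gpu) {
    if (auto gpu = llvm::dyn_cast<GpuOpTargetInterface>(op))
      return gpu.GetOpCost(op);
  }
  if (auto cpu = llvm::dyn_cast<CpuOpTargetInterface>(op))
    return cpu.GetOpCost(op);
  return -1.0;  // Assumed sentinel: no target interface registered.
}
```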

View File

@ -149,6 +149,9 @@ static StatusOr<tflite::TensorType> GetTFLiteType(Type type,
if (ftype && ftype.isF32()) {
return tflite::TensorType_COMPLEX64;
}
if (ftype && ftype.isF64()) {
return tflite::TensorType_COMPLEX128;
}
return Status(error::INVALID_ARGUMENT, "Unsupported type");
}
case mlir::StandardTypes::Integer: {
@ -1193,22 +1196,35 @@ Optional<BufferOffset<tflite::SubGraph>> Translator::BuildSubGraph(
if (IsConst(&inst)) continue;
// Fetch operand and result tensor indices.
std::vector<int32_t> operands;
operands.reserve(inst.getNumOperands());
for (auto operand : inst.getOperands()) {
if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));
}
std::vector<int32_t> results;
results.reserve(inst.getNumResults());
for (auto result : inst.getResults()) {
results.push_back(tensor_index_map.lookup(result));
}
Operation* real_inst = &inst;
// CustomTfOp is just a wrapper around a TF op; we export the custom op,
// not the wrapper, so we fetch the op from the region.
if (auto custom_op = dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
// If we have a custom op with a region, use the first op in the
// region if it exists; otherwise just use the params for the custom op.
if (!custom_op.body().empty()) {
real_inst = &custom_op.body().front().front();
} else {
module_.emitError(
"Invalid CustomTfOp: Custom TF Op have empty region.");
}
}
std::vector<int32_t> operands;
operands.reserve(real_inst->getNumOperands());
for (auto operand : real_inst->getOperands()) {
if (operand.getType().isa<NoneType>())
operands.push_back(kTfLiteOptionalTensor);
else
operands.push_back(tensor_index_map.lookup(operand));
}
if (auto tfl_operator =
BuildOperator(&inst, operands, results, intermediates))
BuildOperator(real_inst, operands, results, intermediates))
operators.push_back(*tfl_operator);
else
failed_once = true;
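The unwrapping rule above can be read as a small standalone helper. This is a sketch only; the function name is assumed, and the logic mirrors the exporter change.

```cpp
// Sketch: prefer the first op of a CustomTfOp's wrapped region, otherwise
// keep the instruction itself (hypothetical helper, name assumed).
mlir::Operation* GetRealInstruction(mlir::Operation* inst) {
  if (auto custom_op = llvm::dyn_cast<mlir::TFL::CustomTfOp>(inst)) {
    if (!custom_op.body().empty())
      return &custom_op.body().front().front();
  }
  return inst;
}
```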

View File

@ -19,7 +19,6 @@ limitations under the License.
#define TFL_OP_INTERFACES
include "mlir/IR/OpBase.td"
include "tensorflow/compiler/mlir/lite/experimental/tfl_hardware_interfaces.td"
//===----------------------------------------------------------------------===//
// TFL op interface for stateful operands.

View File

@ -48,14 +48,9 @@ class TensorFlowLiteDialect : public Dialect {
Location loc) override;
};
#include "tensorflow/compiler/mlir/lite/experimental/estimators/estimator.h"
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops_interface.h.inc"
#define GET_OP_CLASSES
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h.inc"
// Include all specialized estimators below this line
#include "tensorflow/compiler/mlir/lite/experimental/estimators/arithmetic_count_util.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/cpu_estimators.h"
#include "tensorflow/compiler/mlir/lite/experimental/estimators/gpu_estimators.h"
} // end namespace TFL
} // end namespace mlir

View File

@ -410,10 +410,7 @@ def TFL_ComparisonBinaryBuilder : OpBuilder<
class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TFL_Dialect, mnemonic, !listconcat(traits,
[DeclareOpInterfaceMethods<TFL_RuntimeVerification>,
// All TFL ops are supported on CPU.
DeclareOpInterfaceMethods<TFL_CpuTargetOp>
])> {
[DeclareOpInterfaceMethods<TFL_RuntimeVerification>])> {
// FlatBuffer generation specific information.
// -------------------------------------------
// When generating the FlatBuffer output some operations have
@ -435,8 +432,7 @@ class TFL_Op<string mnemonic, list<OpTrait> traits = []> :
class TFL_ConvOp<string mnemonic, string opSummary, int index> :
TFL_Op<mnemonic, [NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>,
TFL_GpuTargetOp, TFL_SparseOp]> {
AffineQuantizedOpInterface, AffineOpCoefficient<index, 1>, TFL_SparseOp]> {
let summary = opSummary # " operator";
let description = [{
@ -473,8 +469,7 @@ def TFL_AbsOp : TFL_Op<"abs", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Absolute value operator";
let description = [{
@ -495,8 +490,7 @@ def TFL_AddOp : TFL_Op<"add", [
CPred<"TFL::VerifyAddOpShapeConstraints(llvm::cast<AddOp>($_op))">>,
ResultsBroadcastableShape,
NoSideEffect,
Commutative,
TFL_GpuTargetOp]> {
Commutative]> {
let summary = "Addition operator";
let description = [{
@ -573,7 +567,6 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
TFL_TCresVTEtIsSameAsOp<0, 2>>,
AccumulatorUniformScale<3, 1, 2>,
AffineQuantizedOpInterface, AffineOpCoefficient<0, 2>,
TFL_GpuTargetOp,
TFL_SparseOp]> {
let summary = "Transpose convolution operator";
@ -612,8 +605,7 @@ def TFL_TransposeConvOp: TFL_Op<"transpose_conv", [
def TFL_AveragePool2DOp:
TFL_Op<"average_pool_2d",
[NoSideEffect,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Average_pool_2d operator";
let description = [{
@ -713,8 +705,7 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
NoSideEffect,
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultsScale,
TFL_GpuTargetOp
SameOperandsAndResultsScale
]> {
let summary = "Concatenation operator";
@ -861,8 +852,7 @@ def TFL_CosOp: TFL_Op<"cos", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Cosine operator";
let description = [{
@ -916,8 +906,7 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
AffineQuantizedOpInterface,
AffineOpCoefficient<-1, 1>,
TFL_SparseOp,
TFL_GpuTargetOp]> {
TFL_SparseOp]> {
let summary = "Fully connected op";
let arguments = (ins
@ -1070,8 +1059,7 @@ def TFL_LessEqualOp : TFL_Op<"less_equal", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
NoSideEffect]> {
let summary = "Less_equal operator";
let description = [{
@ -1132,8 +1120,7 @@ convolutional neural networks (NIPS 2012)](http://papers.nips.cc/paper/4824-imag
def TFL_GreaterEqualOp : TFL_Op<"greater_equal", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult]> {
NoSideEffect]> {
let summary = "Greater_equal operator";
let description = [{
@ -1360,8 +1347,7 @@ def TFL_DivOp : TFL_Op<"div", [
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Division operator";
let description = [{
@ -1427,7 +1413,6 @@ def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup",
def TFL_EqualOp: TFL_Op<"equal", [
Commutative,
NoQuantizableResult,
ResultsBroadcastableShape,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
PredOpTrait<"Operands have same value type", TCopVTEtIsSameAs<0, 1>>]> {
@ -1449,8 +1434,7 @@ def TFL_EqualOp: TFL_Op<"equal", [
}
def TFL_ExpOp: TFL_Op<"exp", [NoSideEffect,
SameOperandsAndResultType,
TFL_GpuTargetOp]> {
SameOperandsAndResultType]> {
let summary = "Natural exponentiation operator";
let description = [{
@ -1634,8 +1618,7 @@ def TFL_GreaterOp : TFL_Op<"greater", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
NoSideEffect]> {
let summary = "Greater operator";
let description = [{
@ -1659,8 +1642,7 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [
NoSideEffect,
SameOperandsAndResultShape,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_GpuTargetOp]> {
TFL_TCresVTEtIsSameAsOp<0, 0>>]> {
let summary = "Hardswish activation function.";
let description = [{
Computes hard-swish activation function
@ -1735,8 +1717,7 @@ def TFL_LessOp : TFL_Op<"less", [
ResultsBroadcastableShape,
BinaryOpSameElementTypeConstraint,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoSideEffect,
NoQuantizableResult]> {
NoSideEffect]> {
let summary = "Less operator";
let description = [{
@ -1812,8 +1793,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [
PredOpTrait<"x and y must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultShape,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
FixedOutputRangeInterface]> {
let summary = "Logistic operator";
let description = [{
@ -1841,8 +1821,7 @@ def TFL_LogOp: TFL_Op<"log", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Natural logarithm operator";
let description = [{
@ -1908,8 +1887,7 @@ def TFL_MaxPool2DOp : TFL_Op<"max_pool_2d", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
MaxPoolOperandAndResultConstraints,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Max Pool 2D op";
let description = [{
@ -1941,8 +1919,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Max operator";
let description = [{
Element-wise max operation.
@ -1965,8 +1942,7 @@ def TFL_MaximumOp : TFL_Op<"maximum", [
def TFL_MeanOp : TFL_Op<"mean", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
TFL_GpuTargetOp]> {
NoSideEffect]> {
let summary = "Mean operator";
let description = [{
@ -2044,8 +2020,7 @@ def TFL_SliceOp : TFL_Op<"slice", [
SameOperandsAndResultsScale,
TFL_OperandHasRankAtMost<0, 4>,
TFL_OperandHasRankAtMost<1, 1>,
TFL_OperandHasRankAtMost<2, 1>,
TFL_GpuTargetOp]> {
TFL_OperandHasRankAtMost<2, 1>]> {
let summary = "Return a slice from 'input'.";
let description = [{
@ -2176,8 +2151,7 @@ def TFL_MinimumOp : TFL_Op<"minimum", [
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 5>,
Commutative,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Min operator";
let description = [{
Element-wise min operation.
@ -2203,8 +2177,7 @@ def TFL_MulOp : TFL_Op<"mul", [
Commutative,
BinaryOpSameElementTypeConstraint,
TFL_RuntimePredOpTrait<"Operands do not have valid shapes",
CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>,
TFL_GpuTargetOp]> {
CPred<"TFL::VerifyMulOpShapeConstraints(llvm::cast<MulOp>($_op))">>]> {
let summary = "Multiplication operator";
let description = [{
@ -2310,8 +2283,7 @@ def TFL_PadOp : TFL_Op<"pad", [
TFL_OperandRankEquals1DimOfOperand<0, 1>,
PredOpTrait<"the first dim size of the padding argument must be at most 4",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 4>]>>,
TFL_GpuTargetOp]> {
TFL_OperandDimIsAtMost<1, 0, 4>]>>]> {
let summary = "Padding operator";
let description = [{
@ -2404,8 +2376,7 @@ def TFL_PowOp : TFL_Op<"pow", [
ResultsBroadcastableShape,
NoSideEffect,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Power operator";
let description = [{
@ -2428,7 +2399,6 @@ def TFL_PowOp : TFL_Op<"pow", [
def TFL_PReluOp : TFL_Op<"prelu", [
NoSideEffect,
ResultsBroadcastableShape,
TFL_GpuTargetOp,
TFL_OperandsHaveSameShapesOrBroadcastableShape<[0, 1], 4>,
BinaryOpSameElementTypeConstraint,
PredOpTrait<"input and output must have the same element type",
@ -2470,8 +2440,7 @@ def TFL_ReluOp: TFL_Op<"relu", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Relu operator";
let description = [{
@ -2500,8 +2469,7 @@ def TFL_Relu6Op: TFL_Op<"relu6", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Relu6 operator";
let description = [{
@ -2555,7 +2523,7 @@ def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [
}
def TFL_ReshapeOp: TFL_Op<"reshape", [
NoSideEffect, SameOperandsAndResultsScale, TFL_GpuTargetOp]> {
NoSideEffect, SameOperandsAndResultsScale]> {
let summary = "Reshape operator";
let description = [{
@ -2610,8 +2578,7 @@ slice `i`, with the first `seq_lengths[i]` slices along dimension
def TFL_RsqrtOp: TFL_Op<"rsqrt", [NoSideEffect,
SameOperandsAndResultType,
SameOperandsAndResultShape,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Reciprocal of square root operator";
let description = [{
@ -2706,6 +2673,7 @@ def TFL_ReverseV2Op: TFL_Op<"reverse_v2", [
// are unranked. Therefore, we skip adding shape constraints here.
def TFL_SelectOp : TFL_Op<"select", [
NoSideEffect,
SameOperandsAndResultsScale,
PredOpTrait<"operands have same element type", TCopVTEtIsSameAs<1, 2>>,
PredOpTrait<"operands and result have same element type",
TFL_TCresVTEtIsSameAsOp<0, 1>>]> {
@ -2777,8 +2745,7 @@ def TFL_SinOp: TFL_Op<"sin", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Sine operator";
let description = [{
@ -2798,8 +2765,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankRange<0, 1, 4>,
SameOperandsAndResultShape,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
FixedOutputRangeInterface]> {
let summary = "Softmax operator";
let description = [{
@ -2834,8 +2800,7 @@ def TFL_SqrtOp: TFL_Op<"sqrt", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Square root operator";
let description = [{
@ -2853,8 +2818,7 @@ def TFL_SquareOp: TFL_Op<"square", [
NoSideEffect,
SameOperandsAndResultShape,
SameOperandsAndResultType,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Square operator";
let description = [{
@ -2907,8 +2871,7 @@ def TFL_SquaredDifferenceOp : TFL_Op<"squared_difference", [
SameOperandsAndResultElementType,
ResultsBroadcastableShape,
NoSideEffect,
NoQuantizableResult,
TFL_GpuTargetOp]> {
NoQuantizableResult]> {
let summary = "Squared difference operator";
let description = [{
@ -2933,8 +2896,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [
SameOperandsAndResultShape,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
FixedOutputRangeInterface,
TFL_GpuTargetOp]> {
FixedOutputRangeInterface]> {
let summary = "Hyperbolic tangent operator";
let description = [{
@ -3035,8 +2997,7 @@ def TFL_TransposeOp : TFL_Op<"transpose", [
TFL_OperandHasRank<1, 1>,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultsScale,
TFL_GpuTargetOp]> {
SameOperandsAndResultsScale]> {
let summary = "Transpose operator";
let description = [{
@ -3170,8 +3131,7 @@ def TFL_SpaceToDepthOp: TFL_Op<"space_to_depth", [
SameOperandsAndResultsScale,
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankAtMost<0, 4>,
TFL_GpuTargetOp
TFL_OperandHasRankAtMost<0, 4>
]> {
let summary = "SpaceToDepth operator";
@ -3383,8 +3343,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
TFL_OperandHasRankAtMost<0, 5>,
TFL_OperandHasRank<1, 1>,
TFL_OperandHasRank<2, 1>,
TFL_OperandHasRank<3, 1>,
TFL_GpuTargetOp
TFL_OperandHasRank<3, 1>
]> {
let summary = "StridedSlice Op";
@ -3434,7 +3393,7 @@ def TFL_CastOp : TFL_Op<"cast", [
}
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
let description = [{
@ -4337,7 +4296,8 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
let hasCanonicalizer = 1;
}
def TFL_CustomOp : Op<TFL_Dialect, "custom", [NoSideEffect]> {
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
NoSideEffect, NoQuantizableResult]> {
let summary = "Custom op";
let description = [{
@ -4360,4 +4320,80 @@ def TFL_CustomOp : Op<TFL_Dialect, "custom", [NoSideEffect]> {
let verifier = [{ return Verify(*this); }];
}
def TFL_CustomTfOp : Op<TFL_Dialect, "custom_tf", [
// Currently the custom ops have no side effect
// TODO(karimnosseir): Revisit if this needs updating.
NoSideEffect,
NoQuantizableResult,
SingleBlockImplicitTerminator<"YieldOp">]> {
let summary = "Wrapper Op for TF custom ops.";
let description = [{
    A wrapper op around any custom TF op. These include ops defined using
    custom_opdefs or linked ops that are not defined in the TF dialect.
    This op just wraps the custom op inside a region.
    Note #1: this op will not include TF Lite custom ops defined using CustomOp.
    Note #2: this op is just an internal representation inside the converter and
    is not exposed/exported when the model is exported to FlatBuffer.
}];
let arguments = (ins
Variadic<TFL_TensorOfOrNone<[AnyType]>>:$input
);
let results = (outs Variadic<AnyTensor>:$output);
let regions = (region SizedRegion<1>:$body);
}
def TFL_BroadcastToOp : TFL_Op<"broadcast_to", [
PredOpTrait<"input and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
TFL_OperandHasRankAtMost<0, 8>,
TFL_OperandHasRank<1, 1>,
PredOpTrait<"output dimension count must be at most 8",
Or<[TFL_OperandIsUnrankedPred<1>,
TFL_OperandDimIsAtMost<1, 0, 8>]>>,
NoSideEffect]> {
let summary = "Broadcast an array for a compatible shape.";
let description = [{
    Broadcasting is the process of making arrays have compatible shapes
    for arithmetic operations. Two shapes are compatible if, for each
    dimension pair, they are either equal or one of them is one. When trying
    to broadcast a Tensor to a shape, it starts with the trailing dimensions
    and works its way forward.
For example,
>>> x = tf.constant([1, 2, 3])
>>> y = tf.broadcast_to(x, [3, 3])
>>> print(y)
tf.Tensor(
[[1 2 3]
[1 2 3]
[1 2 3]], shape=(3, 3), dtype=int32)
    In the above example, the input Tensor with shape `[1, 3]`
    is broadcast to an output Tensor with shape `[3, 3]`.
When doing broadcasted operations such as multiplying a tensor
by a scalar, broadcasting (usually) confers some time or space
benefit, as the broadcasted tensor is never materialized.
However, `broadcast_to` does not carry with it any such benefits.
The newly-created tensor takes the full memory of the broadcasted
    shape. (In a graph context, `broadcast_to` might be fused into the
    subsequent operation and then be optimized away, however.)
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$input,
TFL_I32OrI64Tensor:$shape
);
let results = (outs
TFL_TensorOf<[F32, I32, I1, I8, QI8, UI8, QUI8, I16, QI16, I64, Complex<F<32>>]>:$output
);
}
#endif // TFL_OPS

View File

@ -151,10 +151,13 @@ Status ConvertSavedModelToTFLiteFlatBuffer(
return errors::Unimplemented("Only support a single exported name.");
}
tensorflow::GraphImportConfig specs;
specs.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(auto module,
ImportSavedModel(model_flags.saved_model_dir(),
model_flags.saved_model_version(), tags,
exported_names, &context));
exported_names, specs, &context));
if (!model_flags.input_arrays().empty() ||
!model_flags.output_arrays().empty()) {

View File

@ -123,6 +123,8 @@ DataType ConvertIODataTypeToDataType(toco::IODataType dtype) {
return DT_BOOL;
case toco::IODataType::COMPLEX64:
return DT_COMPLEX64;
case toco::IODataType::COMPLEX128:
return DT_COMPLEX128;
default:
return DT_INVALID;
}

View File

@ -81,7 +81,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:Analysis",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:QuantOps",

View File

@ -36,11 +36,11 @@ versions {
producer: 27
}
# CHECK-LABEL: func @main
# CHECK-SAME: (%[[ARG_0:[a-z0-9]+]]: tensor<4xi32>, %[[ARG_1:[a-z0-9]+]]: tensor<4xi32>) -> tensor<*xi32>
# CHECK-SAME: control_outputs = ""
# CHECK-SAME: inputs = "input0,input1"
# CHECK-SAME: outputs = "output"
# CHECK-NEXT: %[[OP:[a-z0-9]+]] = "tf.BannaPotatoSaladWithColeslaw"(%[[ARG_0]], %[[ARG_1]]) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[OP]] : tensor<*xi32>
# CHECK-NEXT: }
# CHECK-LABEL: func @main(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<*xi32>
# CHECK: attributes {tf.entry_function = {control_outputs = "", inputs = "input0,input1", outputs = "output"}} {
# CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%arg0, %arg1) ( {
# CHECK-NEXT: %[[OUTPUTS:.*]] = "tf.BannaPotatoSaladWithColeslaw"(%arg0, %arg1) {T = i32, device = ""} : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: "tfl.yield"(%[[OUTPUTS]]) : (tensor<*xi32>) -> ()
# CHECK-NEXT: }) : (tensor<4xi32>, tensor<4xi32>) -> tensor<*xi32>
# CHECK-NEXT: return %[[CUSTOM]] : tensor<*xi32>
# CHECK-NEXT: }

View File

@ -15,6 +15,13 @@ func @complex64() -> tensor<4xcomplex<f32>> {
return %0 : tensor<4xcomplex<f32>>
}
func @complex128() -> tensor<4xcomplex<f64>> {
// CHECK-LABEL: @complex128
// CHECK: value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>>
%0 = "tfl.pseudo_const"() { value = opaque<"tf", "0x746674656E736F722464747970653A2044545F434F4D504C45583132382074656E736F725F7368617065207B2064696D207B2073697A653A2034207D207D2074656E736F725F636F6E74656E743A20225C3030305C3030305C3030305C3030305C3030305C3030305C3336303F5C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303130405C3030305C3030305C3030305C3030305C3030305C3030305C303030405C3030305C3030305C3030305C3030305C3030305C3030305C303230405C3030305C3030305C3030305C3030305C3030305C3030305C3030304022"> : tensor<4xcomplex<f64>> } : () -> tensor<4xcomplex<f64>>
return %0 : tensor<4xcomplex<f64>>
}
// TODO(b/138847107) this should work but doesn't
// func @f16() -> tensor<4xf16> {
// %0 = "tfl.pseudo_const"() { value = dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf16> } : () -> tensor<4xf16>

View File

@ -1,7 +1,7 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s -split-input-file | FileCheck %s
module {
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -1027,11 +1027,11 @@ module {
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
@ -2160,11 +2160,12 @@ module {
}
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -3190,7 +3191,7 @@ module {
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf.api_implements = "tftext:WhitespaceTokenizer", tf.signature.is_stateful} {
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
// CHECK: return %0 : tensor<?x!tf.string>
}

View File

@ -5,8 +5,6 @@ func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> tensor<3
return %0: tensor<3x3xbf16>
// CHECK-LABEL: broadcast_to_bf16
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<bf16>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi64>, tensor<bf16>) -> tensor<3x3xbf16>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16>
// CHECK: return [[MUL]] : tensor<3x3xbf16>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16>
// CHECK: return [[BCT]] : tensor<3x3xbf16>
}

View File

@ -1487,10 +1487,8 @@ func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<f32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<f32>) -> tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
// CHECK: return [[BCT]] : tensor<3x3xf32>
}
func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
@ -1498,10 +1496,8 @@ func @broadcast_to_i32(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tfl.fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tfl.mul"(%arg0, [[FILL]]) {fused_activation_function = "NONE"} : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
// CHECK: [[BCT:%.*]] = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
// CHECK: return [[BCT]] : tensor<3x3xi32>
}
func @matmul_batch(%arg0: tensor<10x15xf32>, %arg1: tensor<15x17xf32>) -> tensor<10x17xf32> {

View File

@ -277,6 +277,45 @@ func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> ten
// CHECK: return %[[RESULT]] : tensor<i1>
}
// CHECK-LABEL: func @tensorlistWhileRegion
func @tensorlistWhileRegion(%arg0: tensor<2x3xf32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>
%cst_1 = constant dense<-1> : tensor<i32>
%0 = "tf.TensorListFromTensor"(%arg0, %cst) : (tensor<2x3xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<3xf32>>>
// CHECK: "tf.WhileRegion"
%1:2 = "tf.WhileRegion"(%cst_0, %0) ({
^bb0(%carg0: tensor<i32>, %carg1: tensor<!tf.variant>):
%cst_2 = constant dense<2> : tensor<i32>
%1 = "tf.Less"(%carg0, %cst_2) : (tensor<i32>, tensor<i32>) -> tensor<i1>
"tf.Yield"(%1) : (tensor<i1>) -> ()
// verify condition types
// CHECK: ^bb0(%[[CARG0:.*]]: tensor<i32>, %[[CARG1:.*]]: tensor<*xf32>):
// CHECK: %[[COND:.*]] = "tf.Less"(%[[CARG0]], {{.*}}) : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK: "tf.Yield"(%[[COND]]) : (tensor<i1>) -> ()
},
{
^bb0(%barg0: tensor<i32>, %barg1: tensor<!tf.variant>):
%1 = "tf.TensorListLength"(%barg1) : (tensor<!tf.variant>) -> tensor<i32>
"tf.Yield"(%1, %barg1) : (tensor<i32>, tensor<!tf.variant>) -> ()
// verify body types
// CHECK: ^bb0(%[[BARG0:.*]]: tensor<i32>, %[[BARG1:.*]]: tensor<*xf32>):
// CHECK-NOT: tensor<!tf.variant>
// CHECK: %[[LEN:.*]] = "tf.Gather"
// CHECK-NOT: tensor<!tf.variant>
// CHECK: "tf.Yield"(%[[LEN]], %[[BARG1]]) : (tensor<i32>, tensor<*xf32>) -> ()
}) {is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<3xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
// make sure the variant types in input/output have been updated
// CHECK: {is_stateless = false} : (tensor<i32>, tensor<2x3xf32>) -> (tensor<i32>, tensor<*xf32>)
%2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
// CHECK: return %0#1 : tensor<*xf32>
return %2 : tensor<*xf32>
}
func @tensorlistResize(%arg0: tensor<3x10xf32>, %arg1: tensor<1xi32>, %arg2: tensor<i32>) -> tensor<?x10xf32> {
%0 = "tf.TensorListFromTensor"(%arg0, %arg1) : (tensor<3x10xf32>, tensor<1xi32>) -> tensor<!tf.variant<tensor<10xf32>>>
%1 = "tf.TensorListResize"(%0, %arg2) : (tensor<!tf.variant<tensor<10xf32>>>, tensor<i32>) -> tensor<!tf.variant<tensor<10xf32>>>

View File

@ -0,0 +1,66 @@
// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s -emit-select-tf-ops -o - | flatbuffer_to_string - | FileCheck %s
func @main(tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> {
^bb0(%arg0: tensor<4xcomplex<f64>>, %arg1: tensor<4xcomplex<f64>>):
// CHECK: {
// CHECK-NEXT: version: 3,
// CHECK-NEXT: operator_codes: [ {
// CHECK-NEXT: builtin_code: CUSTOM,
// CHECK-NEXT: custom_code: "FlexAdd"
// CHECK-NEXT: } ],
// CHECK-NEXT: subgraphs: [ {
// CHECK-NEXT: tensors: [ {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 1,
// CHECK-NEXT: name: "arg0",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 2,
// CHECK-NEXT: name: "arg1",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: }, {
// CHECK-NEXT: shape: [ 4 ],
// CHECK-NEXT: type: COMPLEX128,
// CHECK-NEXT: buffer: 3,
// CHECK-NEXT: name: "add",
// CHECK-NEXT: quantization: {
// CHECK-EMPTY:
// CHECK-NEXT: }
// CHECK-NEXT: } ],
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: operators: [ {
// CHECK-NEXT: inputs: [ 0, 1 ],
// CHECK-NEXT: outputs: [ 2 ],
// CHECK-NEXT: custom_options: [ 3, 65, 100, 100, 0, 20, 18, 3, 65, 100, 100, 26, 0, 26, 0, 42, 7, 10, 1, 84, 18, 2, 48, 18, 50, 0, 0, 2, 27, 23, 20, 20, 4, 40, 1 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: name: "main"
// CHECK-NEXT: } ],
// CHECK-NEXT: description: "MLIR Converted.",
// CHECK-NEXT: buffers: [ {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-EMPTY:
// CHECK-NEXT: }, {
// CHECK-NEXT: data: [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
// CHECK-NEXT: } ],
// CHECK-NEXT: metadata: [ {
// CHECK-NEXT: name: "min_runtime_version",
// CHECK-NEXT: buffer: 4
// CHECK-NEXT: } ]
// CHECK-NEXT:}
%0 = "tf.Add"(%arg0, %arg1) : (tensor<4xcomplex<f64>>, tensor<4xcomplex<f64>>) -> tensor<4xcomplex<f64>> loc("add")
return %0 : tensor<4xcomplex<f64>>
}

View File

@ -598,6 +598,16 @@ func @testMaxPool2DWrongOperandStorageType(tensor<1x7x7x16x!quant.uniform<i9:f32
// -----
func @testTFLiteDetectionPostProcess(%arg0: tensor<1x64x64x32xf32>, %arg1: tensor<1x64x64x32xf32>, %arg2: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
%0, %1, %2, %3 = "tfl.custom_tf"(%arg0, %arg1, %arg2) ({
%4, %5, %6, %7 = "tf.TFLite_Detection_PostProcess"(%arg0, %arg1, %arg2) {_output_quantized = true, _output_types = [f32, f32, f32, f32], _support_output_type_float_in_quantized_op = true, detections_per_class = 100 : i64, device = "", h_scale = 5.000000e+00 : f32, max_classes_per_detection = 1 : i64, max_detections = 20 : i64, nms_iou_threshold = 6.000000e-01 : f32, nms_score_threshold = 3.000000e-01 : f32, num_classes = 90 : i64, use_regular_nms = false, w_scale = 5.000000e+00 : f32, x_scale = 1.000000e+01 : f32, y_scale = 1.000000e+01 : f32} : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
"tfl.yield"(%4, %5, %6, %7) : (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) -> ()
}) : (tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>, tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
return %0, %1 : tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>
}
// -----
func @testMaxPoolingWithArgMax2D(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>) {
// custom op for "tfl.max_pooling_with_argmax_2d"(%arg0) {filter_h = 2 : i32, filter_w = 2 : i32, padding = "SAME", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
%0, %1 = "tfl.custom"(%arg0) {custom_option = opaque<"tfl", "0x01000000020000000200000002000000020000000000000000000000000000000000000000000000"> : tensor<40xi8>, custom_code = "MaxPoolingWithArgmax2D"} : (tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32xf32>, tensor<1x32x32x32xf32>)
@ -2300,3 +2310,21 @@ func @main(%arg0: tensor<i32>, %arg1: tensor<1xf32>) -> tensor<i32> {
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>)
return %0#0 : tensor<i32>
}
// -----
// CHECK-LABEL: testBroadcastToWithI32ShapeTensor
func @testBroadcastToWithI32ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi32>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi32>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}
// CHECK-LABEL: testBroadcastToWithI64ShapeTensor
func @testBroadcastToWithI64ShapeTensor(tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32> {
^bb0(%arg0: tensor<?x?x?x?x?x?xf32>, %arg1: tensor<8xi64>):
// CHECK: "tfl.broadcast_to"(%arg0, %arg1)
%0 = "tfl.broadcast_to"(%arg0, %arg1): (tensor<?x?x?x?x?x?xf32>, tensor<8xi64>) -> tensor<?x?x?x?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?x?x?x?xf32>
}

View File

@ -992,3 +992,13 @@ func @RemoveCast(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
// CHECK: return %arg0
}
func @squaredDifferenceReluRemoveRelu(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
%0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "tfl.relu"(%0) : (tensor<1xf32>) -> tensor<1xf32>
return %1: tensor<1xf32>
// CHECK-LABEL: squaredDifferenceReluRemoveRelu
// CHECK: %[[RESULT:.*]] = tfl.squared_difference %arg0, %arg1 : tensor<1xf32>
// CHECK: return %[[RESULT]]
}

View File

@ -481,3 +481,15 @@ func @nms_padded_invalid_num_args(%arg0: tensor<100x4xf32>, %arg1: tensor<100xf3
// expected-error @+1 {{TFLite does not support batched input for non_max_suppression_padded}}
func @nms_padded_with_batches(%arg0: tensor<2x100x4xf32>, %arg1: tensor<2x100xf32>, %arg2: tensor<i32>, %arg3: tensor<f32>, %arg4: tensor<f32>, %arg5: tensor<i1>, %arg6: tensor<i1>, %arg7: tensor<i1>, %arg8: tensor<i32>) -> (tensor<2x10xi32>, tensor<i32>) attributes {tf._implements = "non_max_suppression_padded_v2", tf._reference = "mlir"}
}
// -----
module {
// CHECK-LABEL: func @some_func
// CHECK-LABEL: func @func_with_call
func @some_func(%arg0: tensor<100xf32>) -> tensor<100xf32> attributes {tf.api_implements = "lstm_b4e9f0e7-ac55-42bc-8ef2-8496419a608c"}
func @func_with_call(%arg0: tensor<100xf32>) -> tensor<100xf32> {
%0 = call @some_func(%arg0) : (tensor<100xf32>) -> tensor<100xf32>
return %0 : tensor<100xf32>
}
}

View File

@ -578,3 +578,70 @@ func @MatrixSetDiagV3Conversion(%arg0: tensor<3x3xi32>, %arg1: tensor<3xi32>) ->
// CHECK: %[[RES:.*]] = "tf.MatrixSetDiag"(%arg0, %arg1) : (tensor<3x3xi32>, tensor<3xi32>) -> tensor<3x3xi32>
// CHECK: return %[[RES]]
}
func @broadcast_to_f32_low_dim(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_f32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32_low_dim(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<3x3xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32>
return %0: tensor<3x3xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<3x3xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32>
// CHECK: return [[MUL]] : tensor<3x3xi32>
}
func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32>
return %0: tensor<3x3xf32>
// CHECK-LABEL: broadcast_to_low_dim_with_unknown_shape
// CHECK: [[CST:%.*]] = constant dense<1.000000e+00> : tensor<3x3xf32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[CST]]) : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32>
// CHECK: return [[MUL]] : tensor<3x3xf32>
}
func @broadcast_to_i32_low_dim_with_unknown_output(%input: tensor<3xi32>, %shape: tensor<2xi32>) -> tensor<*xi32> {
%0 = "tf.BroadcastTo"(%input, %shape) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32>
return %0: tensor<*xi32>
// CHECK-LABEL: broadcast_to_i32_low_dim_with_unknown_output
// CHECK: [[CST:%.*]] = constant dense<1> : tensor<i32>
// CHECK: [[FILL:%.*]] = "tf.Fill"(%arg1, [[CST]]) : (tensor<2xi32>, tensor<i32>) -> tensor<*xi32>
// CHECK: [[MUL:%.*]] = "tf.Mul"(%arg0, [[FILL]]) : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32>
// CHECK: return [[MUL]] : tensor<*xi32>
}
func @broadcast_to_high_dim_with_unknown_shape(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
return %0: tensor<7x8x1x2x3x4x5x6xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_shape
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<7x8x1x2x3x4x5x6xf32>
// CHECK: return [[BCT]] : tensor<7x8x1x2x3x4x5x6xf32>
}
func @broadcast_to_high_dim_with_unknown_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<8xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_high_dim_with_unknown_output
// CHECK: [[BCT:%.*]] = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<8xi32>) -> tensor<*xf32>
// CHECK: return [[BCT]] : tensor<*xf32>
}
func @broadcast_to_with_unknown_shape_and_output(%arg0: tensor<1x2x3x4x5x6xf32>, %arg1: tensor<*xi32>) -> tensor<*xf32> {
%0 = "tf.BroadcastTo"(%arg0, %arg1) : (tensor<1x2x3x4x5x6xf32>, tensor<*xi32>) -> tensor<*xf32>
return %0: tensor<*xf32>
// CHECK-LABEL: broadcast_to_with_unknown_shape_and_output
// CHECK: "tf.BroadcastTo"(%arg0, %arg1)
}

View File

@ -0,0 +1,20 @@
// RUN: tf-opt -tfl-raise-custom-ops -canonicalize %s -o - | FileCheck %s
// CHECK-LABEL: custom_op
func @custom_op(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = "tfl.pseudo_const" () {value = dense<1.0> : tensor<4xf32>} : () -> tensor<4xf32>
%1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// will be preserved since it has uses.
%2 = "tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// will be removed since it has no uses and no side effects.
"tf.MyCustomOp"(%1, %0) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
return %2 : tensor<4xf32>
// CHECK-NEXT: %[[CST:.*]] = constant dense<1.000000e+00>
// CHECK-NEXT: %[[MUL:.*]] = tfl.mul %arg0, %[[CST]] {fused_activation_function = "NONE"} : tensor<4xf32>
// CHECK-NEXT: %[[CUSTOM:.*]] = "tfl.custom_tf"(%[[MUL]], %[[CST]]) ( {
// CHECK-NEXT: %[[MY_CUSTOM:.*]] = "tf.MyCustomOp"(%[[MUL]], %[[CST]]) {fused_activation_function = "RELU", int_attr = 2 : i32} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: "tfl.yield"(%[[MY_CUSTOM]]) : (tensor<4xf32>) -> ()
// CHECK-NEXT: }) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32>
// CHECK-NEXT: return %[[CUSTOM]] : tensor<4xf32>
}

View File

@ -175,6 +175,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// Add a shape inference pass to optimize away the unnecessary casts.
pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
}
// Inline function calls that are left in the graph after folding functional
// control flow ops (IfOp, CaseOp).
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(
mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification));
pass_manager->addPass(mlir::TFL::CreateOptimizePass());
@ -182,6 +187,7 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// so that it can target constants introduced once TensorFlow Identity ops
// are removed during legalization.
pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
pass_manager->addPass(mlir::TFL::CreateRaiseCustomOpsPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
pass_manager->addNestedPass<mlir::FuncOp>(mlir::createCSEPass());
