merge with master

This commit is contained in:
Daniel Nguyen 2020-08-13 21:20:00 +00:00
commit fab242f3dc
1510 changed files with 52824 additions and 15780 deletions
.bazelrcRELEASE.mdconfigure.cmd
tensorflow
api_template.__init__.pyapi_template_v1.__init__.py
c
cc
compiler

View File

@ -18,8 +18,10 @@
#
# Compiler options:
# cuda_clang: Use clang when building CUDA code.
# c++17: Build with C++17 options
# c++1z: Build with C++17 options
# c++17: Build with C++17 options (links with libc++)
# c++1z: Build with C++17 options (links with libc++)
# c++17_gcc: Build with C++17 options (links with stdlibc++)
# c++1z_gcc: Build with C++17 options (links with stdlibc++)
# avx_linux: Build with avx instruction set on linux.
# avx2_linux: Build with avx2 instruction set on linux.
# native_arch_linux: Build with instruction sets available to the host machine on linux
@ -165,8 +167,18 @@ build:mkl -c opt
# config to build OneDNN backend with a user specified threadpool.
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_dnn_v1_only=true
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_dnn_v1_only=true
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only -c opt
# This config refers to building with CUDA available. It does not necessarily
# mean that we build CUDA op kernels.
build:using_cuda --define=using_cuda=true
@ -268,6 +280,8 @@ build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
build:c++17 --cxxopt=-std=c++1z
build:c++17 --cxxopt=-stdlib=libc++
build:c++1z --config=c++17
build:c++17_gcc --cxxopt=-std=c++1z
build:c++1z_gcc --config=c++17_gcc
# Enable using platform specific build settings, except when cross-compiling for
# mobile platforms.
@ -358,7 +372,6 @@ build --config=v2
test --config=v2
# Enable XLA
build:xla --action_env=TF_ENABLE_XLA=1
build:xla --define=with_xla_support=true
# BEGIN TF REMOTE BUILD EXECUTION OPTIONS

View File

@ -22,6 +22,7 @@
* Code that uses `tf.map_fn`/`tf.cond`/`tf.while_loop`/control flow as op layers and happens to work before TF 2.4. These will explicitly be unsupported now. Converting these ops to Functional API op layers was unreliable before TF 2.4, and prone to erroring incomprehensibly or being silently buggy.
* Code that directly asserts on a Keras symbolic value in cases where ops like `tf.rank` used to return a static or symbolic value depending on if the input had a fully static shape or not. Now these ops always return symbolic values.
* Code already susceptible to leaking tensors outside of graphs becomes slightly more likely to do so now.
* Code that tries directly getting gradients with respect to symbolic Keras inputs/outputs. Use GradientTape on the actual Tensors passed to the already-constructed model instead.
* Code that requires very tricky shape manipulation via converted op layers in order to work, where the Keras symbolic shape inference proves insufficient.
* Code that tries manually walking a `tf.keras.Model` layer by layer and assumes layers only ever have one positional argument. This assumption doesn't hold true before TF 2.4 either, but is more likely to cause issues know.
* Code that manually enters `keras.backend.get_graph()` before building a functional model. This is no longer needed.
@ -33,6 +34,9 @@
shape assumptions (note that you can pass shapes with `None` entries for axes
that are meant to be dynamic). You can also disable the input checking
entirely by setting `model.input_spec = None`.
* XLA:CPU and XLA:GPU devices are no longer registered by default. Use
`TF_XLA_FLAGS=--tf_xla_enable_xla_devices` if you really need them (to be
removed).
## Known Caveats
@ -72,6 +76,11 @@
`tf.data.experimental.service.from_dataset_id` APIs to enable one process
to register a dataset with the tf.data service, and another process to
consume data from the dataset.
* Added support for tf.data service dispatcher fault tolerance. To enable
fault tolerance, configure a `work_dir` when running your dispatcher
server and set `dispatcher_fault_tolerance=True`. The dispatcher will
store its state to `work_dir`, so that on restart it can continue from its
previous state after restart.
* Added optional `exclude_cols` parameter to CsvDataset. This parameter is
the complement of `select_cols`; at most one of these should be specified.
* We have implemented an optimization which reorders data-discarding
@ -81,9 +90,11 @@
option.
* `tf.image`:
* Added deterministic `tf.image.stateless_random_*` functions for each
`tf.image.random_*` function. Given the same seed, the stateless functions
produce the same results independent of how many times the function is
called, and independent of global seed settings.
`tf.image.random_*` function. Added a new op
`stateless_sample_distorted_bounding_box` which is a determinstic
version of `sample_distorted_bounding_box` op. Given the same seed, these
stateless functions/ops produce the same results independent of how many
times the function is called, and independent of global seed settings.
* `tf.distribute`:
* <ADD RELEASE NOTES HERE>
* `tf.keras`:
@ -95,16 +106,39 @@
* Error messages when Functional API construction goes wrong (and when ops cannot be converted to Keras layers automatically) should be clearer and easier to understand.
* `Optimizer.minimize` can now accept a loss `Tensor` and a `GradientTape`
as an alternative to accepting a `callable` loss.
* Added `beta` parameter to FTRL optimizer to match paper.
* Added `mobilenet_v3` to keras application model.
* `Optimizer.__init__` now accepts a `gradient_aggregator` to allow for
customization of how gradients are aggregated across devices, as well as
`gradients_transformers` to allow for custom gradient transformations
(such as gradient clipping).
* `tf.function` / AutoGraph:
* Added `experimental_follow_type_hints` argument for `tf.function`. When
True, the function may use type annotations to optimize the tracing
performance.
* Added support for `iter(DistributedDataset)` in AutoGraph `for` loops.
* AutoGraph now allows creating new symbols inside a TensorFLow loop, if
the values of these symbols at an iteration does not depend on the previous
iteration. These types of loops must run at least one iteration, and will
raise a runtime error otherwise.
Example:
```
for batch in data:
outputs = train_step(batch)
tf.print('final outputs', outputs)
```
See tensorflow/python/autograph/g3doc/reference/limitations.md for more
info.
* `tf.lite`:
* `DynamicBuffer::AddJoinedString()` will now add a separator if the first
string to be joined is empty.
* `TFLiteConverter`:
* Support optional flags `inference_input_type` and `inference_output_type` for full integer quantized models. This allows users to modify the model input and output type to integer types (`tf.int8`, `tf.uint8`) instead of defaulting to float type (`tf.float32`).
* Deprecate `Interpreter::UseNNAPI(bool)` C++ API
* Prefer using `NnApiDelegate()` and related delegate configuration methods directly.
* Add NNAPI Delegation support for requantization use cases by converting the operation into a dequantize-quantize pair.
* <ADD RELEASE NOTES HERE>
* `tf.random`:
* <ADD RELEASE NOTES HERE>
@ -116,6 +150,8 @@
behavior by adjusting the `l2` parameter.
* <ADD RELEASE NOTES HERE>
* XLA Support:
* xla.experimental.compile is deprecated, use
`tf.function(experimental_compile=True)` instead
* <ADD RELEASE NOTES HERE>
* Tracing and Debugging:
* <ADD RELEASE NOTES HERE>

View File

@ -16,5 +16,5 @@
set configure_dir=%~dp0
set configure_dir=%configure_dir:~0,-1%
python %configure_dir%\configure.py %* || ( exit /b )
python "%configure_dir%\configure.py" %* || ( exit /b )
echo Configuration finished

View File

@ -137,7 +137,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -147,7 +147,7 @@ if _running_from_pip_package():
# TODO(gunan): Add sanity checks to loaded modules here.
for _s in _site_packages_dirs:
# Load first party dynamic kernels.
_main_dir = _os.path.join(_s, 'tensorflow_core/core/kernels')
_main_dir = _os.path.join(_s, 'tensorflow/core/kernels')
if _fi.file_exists(_main_dir):
_ll.load_library(_main_dir)

View File

@ -23,6 +23,7 @@ filegroup(
srcs = [
"c_api.h",
"c_api_experimental.h",
"c_api_macros.h",
"tensor_interface.h",
"tf_attrtype.h",
"tf_datatype.h",
@ -61,6 +62,7 @@ filegroup(
name = "pywrap_required_hdrs",
srcs = [
"c_api_internal.h",
"c_api_macros.h",
"conversion_macros.h",
"python_api.h",
"tensor_interface.h",
@ -79,6 +81,7 @@ tf_cuda_library(
hdrs = [
"c_api.h",
"c_api_internal.h",
"c_api_macros.h",
"tf_datatype.h",
"tf_tensor.h",
"tf_tstring.h",
@ -310,6 +313,7 @@ cc_library(
hdrs = ["tf_tensor.h"],
visibility = ["//visibility:public"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",
@ -336,6 +340,7 @@ tf_cuda_library(
],
visibility = ["//tensorflow:internal"],
deps = [
":c_api_macros",
":tensor_interface",
":tf_datatype",
":tf_status",

View File

@ -213,7 +213,6 @@ void TF_Reset(const TF_SessionOptions* opt, const char** containers,
namespace tensorflow {
Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in,
TF_Buffer* out) {
if (out->data != nullptr) {
@ -306,8 +305,8 @@ void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output,
}
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
Status LoadDynamicLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
// TODO(josh11b,mrry): Change Session to be able to use a Graph*
// directly, instead of requiring us to serialize to a GraphDef and
@ -552,7 +551,7 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle,
TF_Library* TF_LoadLibrary(const char* library_filename, TF_Status* status) {
TF_Library* lib_handle = new TF_Library;
status->status = tensorflow::LoadLibrary(
status->status = tensorflow::LoadDynamicLibrary(
library_filename, &lib_handle->lib_handle, &lib_handle->op_list.data,
&lib_handle->op_list.length);
if (!status->status.ok()) {

View File

@ -30,4 +30,17 @@ limitations under the License.
#endif // _WIN32
#endif // SWIG
// TF_Bool is the C API typedef for unsigned char, while TF_BOOL is
// the datatype for boolean tensors.
#ifndef TF_Bool
#define TF_Bool unsigned char
#endif // TF_Bool
// Macro used to calculate struct size for maintaining ABI stability across
// different struct implementations.
#ifndef TF_OFFSET_OF_END
#define TF_OFFSET_OF_END(TYPE, MEMBER) \
(offsetof(TYPE, MEMBER) + sizeof(((TYPE *)0)->MEMBER))
#endif // TF_OFFSET_OF_END
#endif // TENSORFLOW_C_C_API_MACROS_H_

View File

@ -249,6 +249,7 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@ -508,6 +509,27 @@ tf_cuda_cc_test(
],
)
tf_cuda_library(
name = "c_api_remote_test_util",
testonly = 1,
srcs = ["c_api_remote_test_util.cc"],
hdrs = ["c_api_remote_test_util.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":c_api",
":c_api_internal",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"@com_google_absl//absl/strings",
],
)
tf_cuda_cc_test(
name = "c_api_remote_test",
size = "small",
@ -524,6 +546,7 @@ tf_cuda_cc_test(
":c_api",
":c_api_experimental",
":c_api_internal",
":c_api_remote_test_util",
":c_api_test_util",
":tfe_tensorhandle_internal",
"//tensorflow/c:c_test_util",
@ -540,6 +563,25 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_remote_function_test",
size = "small",
srcs = [
"c_api_remote_function_test.cc",
],
# TODO(b/136478427): Figure out how to correctly shut the server down
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
tags = [
"no_windows",
],
deps = [
":c_api_remote_test_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
tf_cuda_cc_test(
name = "c_api_distributed_test",
size = "small",

View File

@ -724,7 +724,7 @@ void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
if (opts->use_tfrt) {
#ifdef PLATFORM_GOOGLE
return tensorflow::wrap(new tfrt::ContextInterface(opts->async));
return tensorflow::wrap(new tfrt::tf::ContextInterface(opts->async));
#else
status->status = tensorflow::errors::Unimplemented("TFRT is not supported");
return nullptr;

View File

@ -518,7 +518,8 @@ void TestDistributedFunctionCancellation(bool inject_error) {
TFE_TensorHandle* var_handle = TestVariable(ctx, 2.0, dev2_name);
EXPECT_NE(var_handle, nullptr);
const string function_def = VariableAddFunctionWithGraphError();
const string function_def = inject_error ? VariableAddFunctionWithGraphError()
: VariableAddFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);

View File

@ -0,0 +1,56 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace {
void TestRemoteExecuteSilentCopiesFunc(bool async, bool remote,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/true,
heavy_load_on_streaming_rpc,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
// TODO(b/162618595): Enable this test once we remove the check of remote
// outputs in ProcessFunctionLibraryRuntime.
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/false, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, DISABLED_RemoteExecuteSilentCopiesLocalAsyncFuncRemoteOutputs) {
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/false,
/*remote_func_outputs=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopiesFunc(/*async=*/true, /*remote=*/false,
/*heavy_load_on_streaming_rpc=*/true);
}
} // namespace

View File

@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
@ -115,225 +117,24 @@ void TestRemoteExecute(bool async) {
TEST(CAPI, RemoteExecute) { TestRemoteExecute(false); }
TEST(CAPI, RemoteExecuteAsync) { TestRemoteExecute(true); }
string MatMulFunction() {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }",
&def));
return def.SerializeAsString();
}
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
void TestRemoteExecuteSilentCopiesOp(bool async, bool remote,
bool remote_func_outputs = false) {
return TestRemoteExecuteSilentCopies(async, remote, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false,
remote_func_outputs);
}
TEST(CAPI, RemoteExecuteSilentCopies) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/true,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/true, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/true);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocal) {
TestRemoteExecuteSilentCopies(/*async=*/false, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
TestRemoteExecuteSilentCopiesOp(/*async=*/false, /*remote=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsync) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false,
/*func=*/false,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFunc) {
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/false);
}
TEST(CAPI, RemoteExecuteSilentCopiesLocalAsyncFuncOrdering) {
// A remote input may be not ready when we start running a function. Test that
// the function execution should wait until the remote input is ready.
TestRemoteExecuteSilentCopies(/*async=*/true, /*remote=*/false, /*func=*/true,
/*heavy_load_on_streaming_rpc=*/true);
TestRemoteExecuteSilentCopiesOp(/*async=*/true, /*remote=*/false);
}
} // namespace

View File

@ -0,0 +1,215 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_remote_test_util.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
using ::tensorflow::string;
string MatMulFunction(const string& matmul_device) {
tensorflow::FunctionDef def;
CHECK(tensorflow::protobuf::TextFormat::ParseFromString(
absl::StrCat(" signature {"
" name: 'MatMulFunction'"
" input_arg {"
" name: 'a'"
" type: DT_FLOAT"
" }"
" input_arg {"
" name: 'b'"
" type: DT_FLOAT"
" }"
" output_arg {"
" name: 'm'"
" type: DT_FLOAT"
" }"
" }"
" node_def {"
" name: 'matmul'"
" op: 'MatMul'"
" input: 'a'"
" input: 'b'"
" device: '",
matmul_device, "'",
" attr {"
" key: 'T'"
" value {"
" type: DT_FLOAT"
" }"
" }"
" }"
" ret {"
" key: 'm'"
" value: 'matmul:product'"
" }"),
&def));
return def.SerializeAsString();
}
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs) {
tensorflow::ServerDef server_def = GetServerDef(3);
// This server def has the task index set to 0.
string serialized = server_def.SerializeAsString();
server_def.set_task_index(1);
std::unique_ptr<tensorflow::GrpcServer> worker_server1;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server1)
.ok());
ASSERT_TRUE(worker_server1->Start().ok());
server_def.set_task_index(2);
std::unique_ptr<tensorflow::GrpcServer> worker_server2;
ASSERT_TRUE(tensorflow::GrpcServer::Create(
server_def, tensorflow::Env::Default(), &worker_server2)
.ok());
ASSERT_TRUE(worker_server2->Start().ok());
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_ContextSetServerDef(ctx, 0, serialized.data(), serialized.size(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_TensorHandle* h0_task0 = TestMatrixTensorHandle(ctx);
TFE_TensorHandle* h1_task0 = TestMatrixTensorHandle(ctx);
std::vector<TFE_TensorHandle*> handles_task0;
if (heavy_load_on_streaming_rpc) {
// Send 50 tensor copy requests to simulate that there have been some RPC
// requests been enqueued.
for (int i = 0; i < 50; ++i) {
handles_task0.push_back(TestMatrixTensorHandle(ctx));
}
}
const char task1_name[] = "/job:localhost/replica:0/task:1/device:CPU:0";
const char task2_name[] = "/job:localhost/replica:0/task:2/device:CPU:0";
std::vector<TFE_TensorHandle*> handles_task2;
for (auto* h_task0 : handles_task0) {
handles_task2.push_back(
TFE_TensorHandleCopyToDevice(h_task0, ctx, task2_name, status));
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
}
auto* h1_task2 =
TFE_TensorHandleCopyToDevice(h1_task0, ctx, task2_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_Op* matmul = nullptr;
if (func) {
const string matmul_device = remote_func_outputs ? task2_name : "";
string function_def = MatMulFunction(matmul_device);
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
status);
CHECK_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
matmul = TFE_NewOp(ctx, "MatMulFunction", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h0_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_OpAddInput(matmul, h1_task2, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else {
// Handles are on task0 (local), and task2, but op is on task1.
matmul = MatMulOp(ctx, h0_task0, h1_task2);
}
if (remote) {
TFE_OpSetDevice(matmul, task1_name, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
} else if (!async) {
// Set the local device to CPU to easily validate mirroring
string cpu_device_name;
ASSERT_TRUE(GetDeviceName(ctx, &cpu_device_name, "CPU"));
TFE_OpSetDevice(matmul, cpu_device_name.c_str(), status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_FALSE(remote_arg->HasLocalMirror(nullptr));
}
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
// TODO(gjn): Add support for waiting on async local mirrors
if (!remote && !async && !remote_func_outputs) {
auto remote_arg =
tensorflow::TensorHandleFromInterface(tensorflow::unwrap(h1_task2));
// The input handles should never change since they have been mirrored.
ASSERT_TRUE(remote_arg->HasLocalMirror(nullptr));
}
auto* retval_task0 = TFE_TensorHandleCopyToDevice(
retvals[0], ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TF_Tensor* t = TFE_TensorHandleResolve(retval_task0, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteTensorHandle(retval_task0);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TFE_DeleteTensorHandle(h0_task0);
TFE_DeleteTensorHandle(h1_task0);
TFE_DeleteTensorHandle(h1_task2);
TFE_DeleteTensorHandle(retvals[0]);
for (auto* h : handles_task0) {
TFE_DeleteTensorHandle(h);
}
for (auto* h : handles_task2) {
TFE_DeleteTensorHandle(h);
}
TFE_DeleteOp(matmul);
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_GetCode(status), TF_OK) << TF_Message(status);
TFE_DeleteExecutor(executor);
if (func) {
TFE_ContextRemoveFunction(ctx, "MatMulFunction", status);
}
TFE_DeleteContext(ctx);
TF_DeleteStatus(status);
// TODO(b/136478427): Figure out how to correctly shut the server down.
worker_server1.release();
worker_server2.release();
}

View File

@ -12,25 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#define TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// Run a function containing a MatMul op and check its output.
// If heavy_load_on_streaming_rpc is true, send some rpc reqeusts before the one
// which creates a remote remote input, to simulate a scenario that the remote
// input is not ready when we start running an op or a function.
void TestRemoteExecuteSilentCopies(bool async, bool remote, bool func,
bool heavy_load_on_streaming_rpc,
bool remote_func_outputs = false);
#include <stddef.h>
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
extern "C" {
size_t TF_TensorHandleListSize(const TF_TensorHandleList* list) {
return tensorflow::unwrap(list)->size();
}
TFE_TensorHandle* TF_TensorHandleListGet(const TF_TensorHandleList* list,
int i) {
return tensorflow::wrap((*tensorflow::unwrap(list))[i]);
}
} // end extern "C"
#endif // TENSORFLOW_C_EAGER_C_API_REMOTE_TEST_UTIL_H_

View File

@ -30,6 +30,9 @@ using tensorflow::string;
namespace tensorflow {
namespace {
// The tests are parameterized on:
// - a string representing the tracing implementation: "mlir" or "graphdef".
// - a boolean that when true enables TFRT as the execution engine.
class UnifiedCAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool>> {
protected:
@ -983,6 +986,10 @@ TEST_P(UnifiedCAPI, TF_ExecutionContextGetTFEContextFromFunctionContextRaises) {
TF_DeleteExecutionContext(graph_ctx);
}
// The above tests are run for a combination of:
// - graphdef and MLIR tracing engine
// - Using TFRT as an execution runtime (true == enable TFRT)
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI,
::testing::Combine(::testing::Values("graphdef",

View File

@ -363,6 +363,10 @@ Status Execute(AbstractOperation* op_, AbstractContext* ctx,
input_ids[i] = ToId(forward_op_->inputs[i]);
input_dtypes[i] = forward_op_->inputs[i]->DataType();
}
for (int i = 0; i < *num_retvals; i++) {
// TODO(srbs): Manage refcount of ForwardOperation's inputs/outputs.
forward_op_->outputs.push_back(retvals[i]);
}
std::vector<TapeTensor> tape_tensors;
for (auto t : retvals) {
tape_tensors.push_back(TapeTensor(t, ctx));

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
@ -35,6 +36,8 @@ namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using std::vector;
using tracing::TracingOperation;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
@ -45,7 +48,9 @@ class CppGradients
};
Status RegisterGradients(GradientRegistry* registry) {
return registry->Register("Add", AddRegisterer);
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
return Status::OK();
}
// Computes `inputs[0] + inputs[1]` and records it on the tape.
@ -58,9 +63,9 @@ Status Add(AbstractContext* ctx, Tape* tape,
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
if (isa<tracing::TracingOperation>(add_op.get())) {
if (isa<TracingOperation>(add_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
dyn_cast<TracingOperation>(add_op.get())->SetOpName("my_add"));
}
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
@ -69,6 +74,26 @@ Status Add(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `exp(inputs[0])` and records it on the tape.
Status Exp(AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractOperationPtr exp_op(ctx->CreateOperation());
ForwardOperation forward_op;
forward_op.ctx = ctx;
TF_RETURN_IF_ERROR(
Reset(exp_op.get(), "Exp", /*raw_device_name=*/nullptr, &forward_op));
if (isa<TracingOperation>(exp_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(exp_op.get())->SetOpName("my_exp"));
}
TF_RETURN_IF_ERROR(AddInput(exp_op.get(), inputs[0], &forward_op));
int num_retvals = 1;
return Execute(exp_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
registry);
}
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
@ -101,6 +126,35 @@ Status AddGradModel(AbstractContext* ctx,
return Status::OK();
}
// Computes
// y = exp(inputs[0])
// return grad(y, {inputs[0]})
Status ExpGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
tape->Watch(ToId(inputs[0])); // Watch x.
std::vector<AbstractTensorHandle*> exp_outputs(1);
TF_RETURN_IF_ERROR(Exp(ctx, tape, inputs, absl::MakeSpan(exp_outputs),
registry)); // Compute x+y.
std::unordered_map<tensorflow::int64, TapeTensor>
source_tensors_that_are_targets;
std::vector<AbstractTensorHandle*> out_grads;
TF_RETURN_IF_ERROR(tape->ComputeGradient(
vspace, /*target_tensor_ids=*/{ToId(exp_outputs[0])},
/*source_tensor_ids=*/{ToId(inputs[0])}, source_tensors_that_are_targets,
/*output_gradients=*/{}, &out_grads));
for (auto exp_output : exp_outputs) {
exp_output->Unref();
}
outputs[0] = out_grads[0];
delete tape;
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@ -132,26 +186,42 @@ Status RunModel(Model model, AbstractContext* ctx,
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
std::vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
OutputList output_list;
output_list.expected_num_outputs = outputs.size();
output_list.outputs.resize(outputs.size());
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(output_list.outputs), registry));
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
output_list.outputs[0]->Unref();
output_list.outputs[1]->Unref();
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
@ -160,8 +230,19 @@ Status RunModel(Model model, AbstractContext* ctx,
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size();
TF_RETURN_IF_ERROR(fn_op->Execute(outputs, &retvals));
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
@ -264,18 +345,62 @@ TEST_P(CppGradients, TestAddGrad) {
TF_DeleteTensor(result_tensor);
}
TEST_P(CppGradients, TestExpGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Pseudo-code:
//
// tape.watch(x)
// y = exp(x)
// outputs = tape.gradient(y, x)
std::vector<AbstractTensorHandle*> outputs(1);
s = RunModel(ExpGradModel, ctx.get(), {x.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* result_tensor;
s = getValue(outputs[0], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_NEAR(*result_value, 2.718, 0.001);
outputs[0]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef"),
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif

View File

@ -191,8 +191,8 @@ void* TF_LoadSharedLibrary(const char* library_filename, TF_Status* status) {
void* handle = nullptr;
TF_SetStatus(status, TF_OK, "");
::tensorflow::Set_TF_Status_from_Status(
status,
::tensorflow::Env::Default()->LoadLibrary(library_filename, &handle));
status, ::tensorflow::Env::Default()->LoadDynamicLibrary(library_filename,
&handle));
return handle;
}

View File

@ -35,8 +35,8 @@ using UniquePtrTo_TF_Status =
::std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
Status ModularFileSystem::NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) {
if (ops_->new_random_access_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewRandomAccessFile()"));
@ -55,8 +55,8 @@ Status ModularFileSystem::NewRandomAccessFile(
}
Status ModularFileSystem::NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_writable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewWritableFile()"));
@ -75,8 +75,8 @@ Status ModularFileSystem::NewWritableFile(
}
Status ModularFileSystem::NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>* result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) {
if (ops_->new_appendable_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support NewAppendableFile()"));
@ -95,8 +95,8 @@ Status ModularFileSystem::NewAppendableFile(
}
Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
const std::string& fname, std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token */) {
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) {
if (ops_->new_read_only_memory_region_from_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname,
@ -116,8 +116,8 @@ Status ModularFileSystem::NewReadOnlyMemoryRegionFromFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::FileExists(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::FileExists(const std::string& fname,
TransactionToken* token) {
if (ops_->path_exists == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support FileExists()"));
@ -129,9 +129,9 @@ Status ModularFileSystem::FileExists(
return StatusFromTF_Status(plugin_status.get());
}
bool ModularFileSystem::FilesExist(
const std::vector<std::string>& files,
std::vector<Status>* status /*, TransactionToken* token */) {
bool ModularFileSystem::FilesExist(const std::vector<std::string>& files,
TransactionToken* token,
std::vector<Status>* status) {
if (ops_->paths_exist == nullptr)
return FileSystem::FilesExist(files, status);
@ -162,9 +162,9 @@ bool ModularFileSystem::FilesExist(
return result;
}
Status ModularFileSystem::GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetChildren(const std::string& dir,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_children == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dir, " does not support GetChildren()"));
@ -188,9 +188,9 @@ Status ModularFileSystem::GetChildren(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>* result /*, TransactionToken* token */) {
Status ModularFileSystem::GetMatchingPaths(const std::string& pattern,
TransactionToken* token,
std::vector<std::string>* result) {
if (ops_->get_matching_paths == nullptr)
return internal::GetMatchingPaths(this, Env::Default(), pattern, result);
@ -211,8 +211,8 @@ Status ModularFileSystem::GetMatchingPaths(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteFile(
const std::string& fname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteFile(const std::string& fname,
TransactionToken* token) {
if (ops_->delete_file == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support DeleteFile()"));
@ -224,9 +224,10 @@ Status ModularFileSystem::DeleteFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteRecursively(const std::string& dirname,
TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) {
if (undeleted_files == nullptr || undeleted_dirs == nullptr)
return errors::FailedPrecondition(
"DeleteRecursively must not be called with `undeleted_files` or "
@ -247,8 +248,8 @@ Status ModularFileSystem::DeleteRecursively(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::DeleteDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::DeleteDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->delete_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support DeleteDir()"));
@ -260,8 +261,8 @@ Status ModularFileSystem::DeleteDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->recursively_create_dir == nullptr)
return FileSystem::RecursivelyCreateDir(dirname);
@ -272,8 +273,8 @@ Status ModularFileSystem::RecursivelyCreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CreateDir(
const std::string& dirname /*, TransactionToken* token */) {
Status ModularFileSystem::CreateDir(const std::string& dirname,
TransactionToken* token) {
if (ops_->create_dir == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", dirname, " does not support CreateDir()"));
@ -285,9 +286,8 @@ Status ModularFileSystem::CreateDir(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token */) {
Status ModularFileSystem::Stat(const std::string& fname,
TransactionToken* token, FileStatistics* stat) {
if (ops_->stat == nullptr)
return errors::Unimplemented(tensorflow::strings::StrCat(
"Filesystem for ", fname, " does not support Stat()"));
@ -310,8 +310,8 @@ Status ModularFileSystem::Stat(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::IsDirectory(
const std::string& name /*, TransactionToken* token */) {
Status ModularFileSystem::IsDirectory(const std::string& name,
TransactionToken* token) {
if (ops_->is_directory == nullptr) return FileSystem::IsDirectory(name);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -321,9 +321,9 @@ Status ModularFileSystem::IsDirectory(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token */) {
Status ModularFileSystem::GetFileSize(const std::string& fname,
TransactionToken* token,
uint64* file_size) {
if (ops_->get_file_size == nullptr) {
FileStatistics stat;
Status status = Stat(fname, &stat);
@ -342,9 +342,9 @@ Status ModularFileSystem::GetFileSize(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::RenameFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->rename_file == nullptr) {
Status status = CopyFile(src, target);
if (status.ok()) status = DeleteFile(src);
@ -359,9 +359,9 @@ Status ModularFileSystem::RenameFile(
return StatusFromTF_Status(plugin_status.get());
}
Status ModularFileSystem::CopyFile(
const std::string& src,
const std::string& target /*, TransactionToken* token */) {
Status ModularFileSystem::CopyFile(const std::string& src,
const std::string& target,
TransactionToken* token) {
if (ops_->copy_file == nullptr) return FileSystem::CopyFile(src, target);
UniquePtrTo_TF_Status plugin_status(TF_NewStatus(), TF_DeleteStatus);
@ -372,8 +372,7 @@ Status ModularFileSystem::CopyFile(
return StatusFromTF_Status(plugin_status.get());
}
std::string ModularFileSystem::TranslateName(
const std::string& name /*, TransactionToken* token */) const {
std::string ModularFileSystem::TranslateName(const std::string& name) const {
if (ops_->translate_name == nullptr) return FileSystem::TranslateName(name);
char* p = ops_->translate_name(filesystem_.get(), name.c_str());
@ -385,7 +384,7 @@ std::string ModularFileSystem::TranslateName(
return ret;
}
void ModularFileSystem::FlushCaches(/*TransactionToken* token*/) {
void ModularFileSystem::FlushCaches(TransactionToken* token) {
if (ops_->flush_caches != nullptr) ops_->flush_caches(filesystem_.get());
}
@ -462,7 +461,7 @@ Status RegisterFilesystemPlugin(const std::string& dso_path) {
// Step 1: Load plugin
Env* env = Env::Default();
void* dso_handle;
TF_RETURN_IF_ERROR(env->LoadLibrary(dso_path.c_str(), &dso_handle));
TF_RETURN_IF_ERROR(env->LoadDynamicLibrary(dso_path.c_str(), &dso_handle));
// Step 2: Load symbol for `TF_InitPlugin`
void* dso_symbol;

View File

@ -59,71 +59,48 @@ class ModularFileSystem final : public FileSystem {
~ModularFileSystem() override { ops_->cleanup(filesystem_.get()); }
TF_USE_FILESYSTEM_METHODS_WITH_NO_TRANSACTION_SUPPORT;
Status NewRandomAccessFile(
const std::string& fname,
std::unique_ptr<RandomAccessFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewWritableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
Status NewAppendableFile(
const std::string& fname,
std::unique_ptr<WritableFile>*
result /*, TransactionToken* token = nullptr */) override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<RandomAccessFile>* result) override;
Status NewWritableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewAppendableFile(const std::string& fname, TransactionToken* token,
std::unique_ptr<WritableFile>* result) override;
Status NewReadOnlyMemoryRegionFromFile(
const std::string& fname,
std::unique_ptr<ReadOnlyMemoryRegion>*
result /*, TransactionToken* token = nullptr */) override;
Status FileExists(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
const std::string& fname, TransactionToken* token,
std::unique_ptr<ReadOnlyMemoryRegion>* result) override;
Status FileExists(const std::string& fname, TransactionToken* token) override;
bool FilesExist(const std::vector<std::string>& files,
std::vector<Status>*
status /*, TransactionToken* token = nullptr */) override;
Status GetChildren(
const std::string& dir,
std::vector<std::string>* result /*, TransactionToken* token = nullptr */)
override;
Status GetMatchingPaths(
const std::string& pattern,
std::vector<std::string>*
results /*, TransactionToken* token = nullptr */) override;
Status DeleteFile(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status DeleteRecursively(
const std::string& dirname, int64* undeleted_files,
int64* undeleted_dirs /*, TransactionToken* token = nullptr */) override;
Status DeleteDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status RecursivelyCreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status CreateDir(
const std::string& dirname /*, TransactionToken* token = nullptr */)
override;
Status Stat(
const std::string& fname,
FileStatistics* stat /*, TransactionToken* token = nullptr */) override;
Status IsDirectory(
const std::string& fname /*, TransactionToken* token = nullptr */)
override;
Status GetFileSize(
const std::string& fname,
uint64* file_size /*, TransactionToken* token = nullptr */) override;
Status RenameFile(
const std::string& src,
const std::string& target /*, TransactionToken* token = nullptr */)
override;
Status CopyFile(const std::string& src,
const std::string&
target /*, TransactionToken* token = nullptr */) override;
std::string TranslateName(
const std::string& name /*, TransactionToken* token = nullptr */)
const override;
void FlushCaches(/* TransactionToken* token=nullptr */) override;
TransactionToken* token,
std::vector<Status>* status) override;
Status GetChildren(const std::string& dir, TransactionToken* token,
std::vector<std::string>* result) override;
Status GetMatchingPaths(const std::string& pattern, TransactionToken* token,
std::vector<std::string>* results) override;
Status DeleteFile(const std::string& fname, TransactionToken* token) override;
Status DeleteRecursively(const std::string& dirname, TransactionToken* token,
int64* undeleted_files,
int64* undeleted_dirs) override;
Status DeleteDir(const std::string& dirname,
TransactionToken* token) override;
Status RecursivelyCreateDir(const std::string& dirname,
TransactionToken* token) override;
Status CreateDir(const std::string& dirname,
TransactionToken* token) override;
Status Stat(const std::string& fname, TransactionToken* token,
FileStatistics* stat) override;
Status IsDirectory(const std::string& fname,
TransactionToken* token) override;
Status GetFileSize(const std::string& fname, TransactionToken* token,
uint64* file_size) override;
Status RenameFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
Status CopyFile(const std::string& src, const std::string& target,
TransactionToken* token) override;
std::string TranslateName(const std::string& name) const override;
void FlushCaches(TransactionToken* token) override;
private:
std::unique_ptr<TF_Filesystem> filesystem_;

View File

@ -33,7 +33,6 @@ limitations under the License.
// Windows defines the following macros to convert foo to fooA or fooW,
// depending on the type of the string argument. We don't use these macros, so
// undefine them here.
#undef LoadLibrary
#undef CopyFile
#undef DeleteFile
#undef TranslateName

View File

@ -33,6 +33,7 @@ cc_library(
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
@ -663,28 +664,179 @@ void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
}
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
static void StatForObject(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
GcsFileStat* stat, TF_Status* status) {
if (object.empty())
return TF_SetStatus(
status, TF_INVALID_ARGUMENT,
("'object' must be a non-empty string. (File: " + path + ")").c_str());
TF_SetStatus(status, TF_OK, "");
gcs_file->stat_cache->LookupOrCompute(
path, stat,
[gcs_file, bucket, object](const std::string& path, GcsFileStat* stat,
TF_Status* status) {
UncachedStatForObject(bucket, object, stat, &gcs_file->gcs_client,
status);
},
status);
}
static bool ObjectExists(GCSFile* gcs_file, const std::string& path,
const std::string& bucket, const std::string& object,
TF_Status* status) {
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return !stat.base.is_directory;
}
static bool BucketExists(GCSFile* gcs_file, const std::string& bucket,
TF_Status* status) {
auto metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_NOT_FOUND)
return false;
if (TF_GetCode(status) == TF_NOT_FOUND) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static std::vector<std::string> GetChildrenBounded(
GCSFile* gcs_file, std::string dir, uint64_t max_results, bool recursive,
bool include_self_directory_marker, TF_Status* status) {
std::string bucket, prefix;
MaybeAppendSlash(&dir);
ParseGCSPath(dir, true, &bucket, &prefix, status);
std::vector<std::string> result;
uint64_t count = 0;
std::string delimiter = recursive ? "" : "/";
for (auto&& item : gcs_file->gcs_client.ListObjectsAndPrefixes(
bucket, gcs::Prefix(prefix), gcs::Delimiter(delimiter))) {
if (count == max_results) {
TF_SetStatus(status, TF_OK, "");
return result;
}
if (!item) {
TF_SetStatusFromGCSStatus(item.status(), status);
return result;
}
auto value = *std::move(item);
std::string children = absl::holds_alternative<std::string>(value)
? absl::get<std::string>(value)
: absl::get<gcs::ObjectMetadata>(value).name();
auto pos = children.find(prefix);
if (pos != 0) {
TF_SetStatus(status, TF_INTERNAL,
("Unexpected response: the returned file name " + children +
" doesn't match the prefix " + prefix)
.c_str());
return result;
}
children.erase(0, prefix.length());
if (!children.empty() || include_self_directory_marker) {
result.emplace_back(children);
}
++count;
}
return result;
}
static bool FolderExists(GCSFile* gcs_file, std::string dir,
TF_Status* status) {
ExpiringLRUCache<GcsFileStat>::ComputeFunc compute_func =
[gcs_file](const std::string& dir, GcsFileStat* stat, TF_Status* status) {
auto children =
GetChildrenBounded(gcs_file, dir, 1, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
if (!children.empty()) {
stat->base = {0, 0, true};
return TF_SetStatus(status, TF_OK, "");
} else {
return TF_SetStatus(status, TF_INVALID_ARGUMENT, "Not a directory!");
}
};
GcsFileStat stat;
MaybeAppendSlash(&dir);
gcs_file->stat_cache->LookupOrCompute(dir, &stat, compute_func, status);
if (TF_GetCode(status) != TF_OK && TF_GetCode(status) != TF_INVALID_ARGUMENT)
return false;
if (TF_GetCode(status) == TF_INVALID_ARGUMENT) {
TF_SetStatus(status, TF_OK, "");
return false;
}
return true;
}
static void ClearFileCaches(GCSFile* gcs_file, const std::string& path) {
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->RemoveFile(path);
gcs_file->stat_cache->Delete(path);
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
bool result = BucketExists(gcs_file, bucket, status);
if (result) return TF_SetStatus(status, TF_OK, "");
}
GcsFileStat stat;
StatForObject(gcs_file, path, bucket, object, &stat, status);
if (TF_GetCode(status) != TF_NOT_FOUND) return;
bool result = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK || (TF_GetCode(status) == TF_OK && result))
return;
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string dir = path;
MaybeAppendSlash(&dir);
std::string bucket, object;
ParseGCSPath(dir, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
bool is_directory = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return;
if (!is_directory)
TF_SetStatus(status, TF_NOT_FOUND,
("The specified bucket " + dir + " was not found.").c_str());
return;
}
MaybeAppendSlash(&object);
auto object_metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
TF_SetStatusFromGCSStatus(object_metadata.status(), status);
if (TF_GetCode(status) == TF_NOT_FOUND) {
auto insert_metadata =
gcs_file->gcs_client.InsertObject(bucket, object, "");
TF_SetStatusFromGCSStatus(insert_metadata.status(), status);
} else if (TF_GetCode(status) == TF_OK) {
PathExists(filesystem, dir.c_str(), status);
if (TF_GetCode(status) == TF_OK)
return TF_SetStatus(status, TF_ALREADY_EXISTS, path);
auto metadata = gcs_file->gcs_client.InsertObject(
bucket, object, "",
// Adding this parameter means HTTP_CODE_PRECONDITION_FAILED
// will be returned if the object already exists, so avoid reuploading.
gcs::IfGenerationMatch(0));
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
TF_SetStatus(status, TF_ALREADY_EXISTS, path);
}
}
// TODO(vnvo2409): `RecursivelyCreateDir` should use `CreateDir` instead of the
@ -700,79 +852,31 @@ void DeleteFile(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) == TF_OK) ClearFileCaches(gcs_file, path);
}
// Checks that the directory is empty (i.e no objects with this prefix exist).
// Deletes the GCS directory marker if it exists.
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
MaybeAppendSlash(&object);
// A directory is considered empty either if there are no matching objects
// with the corresponding name prefix or if there is exactly one matching
// object and it is the directory marker. Therefore we need to retrieve
// at most two children for the prefix to detect if a directory is empty.
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
int object_count = 0;
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
++object_count;
// We consider a path is a non-empty directory in two cases:
// - There are more than two objects whose keys start with the name of this
// directory.
// - There is one object whose key contains the name of this directory ( but
// not equal ).
if (object_count > 1 || metadata->name() != object) {
TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
return;
}
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
}
// TODO(vnvo2409): `DeleteRecursively` needs `GetChildrens` but there will be
// some differents compared to the default implementation. Will be refactored.
static void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files,
uint64_t* undeleted_dirs, TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
auto childrens = GetChildrenBounded(gcs_file, path, 2, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto gcs_status = gcs::DeleteByPrefix(gcs_file->gcs_client, bucket, object);
TF_SetStatusFromGCSStatus(gcs_status, status);
if (TF_GetCode(status) != TF_OK) return;
*undeleted_dirs = 0;
*undeleted_files = 0;
}
// TODO(vnvo2409): `RewriteObjectBlocking` will set `status` to `TF_NOT_FOUND`
// if the object does not exist. In that case, we will have to check if the
// `src` is a directory or not to set the correspondent `status` (i.e
// `TF_NOT_FOUND` if path `src` does not exist, `TF_FAILED_PRECONDITION` if
// path `src` is a directory).
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (childrens.size() > 1 || (childrens.size() == 1 && !childrens[0].empty()))
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (childrens.size() == 1 && childrens[0].empty()) {
// This is the directory marker object. Delete it.
std::string dir = path;
MaybeAppendSlash(&dir);
DeleteFile(filesystem, dir.c_str(), status);
return;
}
auto gcs_status = gcs_file->gcs_client.DeleteObject(bucket_src, object_src);
TF_SetStatusFromGCSStatus(gcs_status, status);
TF_SetStatus(status, TF_OK, "");
}
void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
@ -791,31 +895,6 @@ void CopyFile(const TF_Filesystem* filesystem, const char* src, const char* dst,
TF_SetStatusFromGCSStatus(metadata.status(), status);
}
// TODO(vnvo2409): This approach can cause a problem when our path is
// `path/to/dir` and there is an object with key `path/to/directory`. Will be
// fixed when refactoring.
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
ParseGCSPath(path, true, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return;
}
// We consider a path exists if there is at least one object whose key
// contains the path.
return TF_SetStatus(status, TF_OK, "");
}
return TF_SetStatus(
status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
}
bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
std::string bucket, object;
@ -824,41 +903,127 @@ bool IsDirectory(const TF_Filesystem* filesystem, const char* path,
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
if (object.empty()) {
auto bucket_metadata = gcs_file->gcs_client.GetBucketMetadata(bucket);
TF_SetStatusFromGCSStatus(bucket_metadata.status(), status);
if (TF_GetCode(status) == TF_OK)
return true;
else
return false;
bool result = BucketExists(gcs_file, bucket, status);
if (TF_GetCode(status) != TF_OK) return false;
if (!result)
TF_SetStatus(
status, TF_NOT_FOUND,
("The specified bucket gs://" + bucket + " was not found.").c_str());
return result;
}
// We check if there is an object with this key on the GCS server.
auto metadata = gcs_file->gcs_client.GetObjectMetadata(bucket, object);
if (metadata) {
TF_SetStatus(status, TF_OK, "");
if (metadata->name().back() == '/')
return true;
else
return false;
}
bool is_folder = FolderExists(gcs_file, path, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_folder) return true;
// If there is no object with this key on the GCS server. We check if there is
// any object whose key contains that path.
MaybeAppendSlash(&object);
for (auto&& metadata :
gcs_file->gcs_client.ListObjects(bucket, gcs::Prefix(object))) {
if (!metadata) {
TF_SetStatusFromGCSStatus(metadata.status(), status);
return false;
}
TF_SetStatus(status, TF_OK, "");
return true;
bool is_object = ObjectExists(gcs_file, path, bucket, object, status);
if (TF_GetCode(status) != TF_OK) return false;
if (is_object) {
TF_SetStatus(
status, TF_FAILED_PRECONDITION,
absl::StrCat("The specified path ", path, " is not a directory.")
.c_str());
return false;
}
TF_SetStatus(status, TF_NOT_FOUND,
absl::StrCat("The path ", path, " does not exist.").c_str());
return false;
}
static void RenameObject(const TF_Filesystem* filesystem,
const std::string& src, const std::string& dst,
TF_Status* status) {
std::string bucket_src, object_src;
ParseGCSPath(src, false, &bucket_src, &object_src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string bucket_dst, object_dst;
ParseGCSPath(dst, false, &bucket_dst, &object_dst, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
auto metadata = gcs_file->gcs_client.RewriteObjectBlocking(
bucket_src, object_src, bucket_dst, object_dst);
TF_SetStatusFromGCSStatus(metadata.status(), status);
if (TF_GetCode(status) != TF_OK) return;
ClearFileCaches(gcs_file, dst);
DeleteFile(filesystem, src.c_str(), status);
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
if (!IsDirectory(filesystem, src, status)) {
if (TF_GetCode(status) == TF_FAILED_PRECONDITION)
RenameObject(filesystem, src, dst, status);
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, src, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string src_dir = src;
std::string dst_dir = dst;
MaybeAppendSlash(&src_dir);
MaybeAppendSlash(&dst_dir);
for (const std::string& children : childrens) {
RenameObject(filesystem, src_dir + children, dst_dir + children, status);
if (TF_GetCode(status) != TF_OK) return;
}
TF_SetStatus(status, TF_OK, "");
}
void DeleteRecursively(const TF_Filesystem* filesystem, const char* path,
uint64_t* undeleted_files, uint64_t* undeleted_dirs,
TF_Status* status) {
if (!undeleted_files || !undeleted_dirs)
return TF_SetStatus(
status, TF_INTERNAL,
"'undeleted_files' and 'undeleted_dirs' cannot be nullptr.");
*undeleted_files = 0;
*undeleted_dirs = 0;
if (!IsDirectory(filesystem, path, status)) {
*undeleted_dirs = 1;
return;
}
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, true, true, status);
if (TF_GetCode(status) != TF_OK) return;
std::string dir = path;
MaybeAppendSlash(&dir);
for (const std::string& children : childrens) {
const std::string& full_path = dir + children;
DeleteFile(filesystem, full_path.c_str(), status);
if (TF_GetCode(status) != TF_OK) {
if (IsDirectory(filesystem, full_path.c_str(), status))
// The object is a directory marker.
(*undeleted_dirs)++;
else
(*undeleted_files)++;
}
}
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
std::vector<std::string> childrens =
GetChildrenBounded(gcs_file, path, UINT64_MAX, false, false, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = childrens.size();
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
for (int i = 0; i < num_entries; i++)
(*entries)[i] = strdup(childrens[i].c_str());
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
std::string bucket, object;
@ -896,6 +1061,17 @@ void Stat(const TF_Filesystem* filesystem, const char* path,
}
}
static char* TranslateName(const TF_Filesystem* filesystem, const char* uri) {
return strdup(uri);
}
static void FlushCaches(const TF_Filesystem* filesystem) {
auto gcs_file = static_cast<GCSFile*>(filesystem->plugin_filesystem);
absl::ReaderMutexLock l(&gcs_file->block_cache_lock);
gcs_file->file_block_cache->Flush();
gcs_file->stat_cache->Clear();
}
} // namespace tf_gcs_filesystem
static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
@ -912,6 +1088,13 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
plugin_memory_allocate(TF_WRITABLE_FILE_OPS_SIZE));
ops->writable_file_ops->cleanup = tf_writable_file::Cleanup;
ops->read_only_memory_region_ops = static_cast<TF_ReadOnlyMemoryRegionOps*>(
plugin_memory_allocate(TF_READ_ONLY_MEMORY_REGION_OPS_SIZE));
ops->read_only_memory_region_ops->cleanup =
tf_read_only_memory_region::Cleanup;
ops->read_only_memory_region_ops->data = tf_read_only_memory_region::Data;
ops->read_only_memory_region_ops->length = tf_read_only_memory_region::Length;
ops->filesystem_ops = static_cast<TF_FilesystemOps*>(
plugin_memory_allocate(TF_FILESYSTEM_OPS_SIZE));
ops->filesystem_ops->init = tf_gcs_filesystem::Init;
@ -921,6 +1104,20 @@ static void ProvideFilesystemSupportFor(TF_FilesystemPluginOps* ops,
ops->filesystem_ops->new_writable_file = tf_gcs_filesystem::NewWritableFile;
ops->filesystem_ops->new_appendable_file =
tf_gcs_filesystem::NewAppendableFile;
ops->filesystem_ops->new_read_only_memory_region_from_file =
tf_gcs_filesystem::NewReadOnlyMemoryRegionFromFile;
ops->filesystem_ops->create_dir = tf_gcs_filesystem::CreateDir;
ops->filesystem_ops->delete_file = tf_gcs_filesystem::DeleteFile;
ops->filesystem_ops->delete_dir = tf_gcs_filesystem::DeleteDir;
ops->filesystem_ops->delete_recursively =
tf_gcs_filesystem::DeleteRecursively;
ops->filesystem_ops->copy_file = tf_gcs_filesystem::CopyFile;
ops->filesystem_ops->path_exists = tf_gcs_filesystem::PathExists;
ops->filesystem_ops->is_directory = tf_gcs_filesystem::IsDirectory;
ops->filesystem_ops->stat = tf_gcs_filesystem::Stat;
ops->filesystem_ops->get_children = tf_gcs_filesystem::GetChildren;
ops->filesystem_ops->translate_name = tf_gcs_filesystem::TranslateName;
ops->filesystem_ops->flush_caches = tf_gcs_filesystem::FlushCaches;
}
void TF_InitPlugin(TF_FilesystemPluginInfo* info) {

View File

@ -442,6 +442,202 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_SetStatus(status, TF_OK, "");
}
void NewReadOnlyMemoryRegionFromFile(const TF_Filesystem* filesystem,
const char* path,
TF_ReadOnlyMemoryRegion* region,
TF_Status* status) {
// hadoopReadZero() technically supports this call with the following
// caveats:
// - It only works up to 2 GB. We'd have to Stat() the file to ensure that
// it fits.
// - If not on the local filesystem, the entire file will be read, making
// it inefficient for callers that assume typical mmap() behavior.
TF_SetStatus(status, TF_UNIMPLEMENTED,
"HDFS does not support ReadOnlyMemoryRegion");
}
void PathExists(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsExists(fs, hdfs_path.c_str()) == 0)
TF_SetStatus(status, TF_OK, "");
else
TF_SetStatus(status, TF_NOT_FOUND,
(std::string(path) + " not found").c_str());
}
void Stat(const TF_Filesystem* filesystem, const char* path,
TF_FileStatistics* stats, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) return TF_SetStatusFromIOError(status, errno, path);
stats->length = static_cast<int64_t>(info->mSize);
stats->mtime_nsec = static_cast<int64_t>(info->mLastMod) * 1e9;
stats->is_directory = info->mKind == kObjectKindDirectory;
libhdfs->hdfsFreeFileInfo(info, 1);
TF_SetStatus(status, TF_OK, "");
}
int64_t GetFileSize(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
auto info = libhdfs->hdfsGetPathInfo(fs, hdfs_path.c_str());
if (info == nullptr) {
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
TF_SetStatus(status, TF_OK, "");
auto size = static_cast<int64_t>(info->mSize);
libhdfs->hdfsFreeFileInfo(info, 1);
return size;
}
void DeleteFile(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/0) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void CreateDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
if (libhdfs->hdfsCreateDirectory(fs, hdfs_path.c_str()) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void DeleteDir(const TF_Filesystem* filesystem, const char* path,
TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// Count the number of entries in the directory, and only delete if it's
// non-empty. This is consistent with the interface, but note that there's
// a race condition where a file may be added after this check, in which
// case the directory will still be deleted.
int entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &entries);
if (info != nullptr) libhdfs->hdfsFreeFileInfo(info, entries);
// Due to HDFS bug HDFS-8407, we can't distinguish between an error and empty
// folder, especially for Kerberos enable setup, EAGAIN is quite common when
// the call is actually successful. Check again by Stat.
if (info == nullptr && errno != 0) {
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return;
}
if (entries > 0)
return TF_SetStatus(status, TF_FAILED_PRECONDITION,
"Cannot delete a non-empty directory.");
if (libhdfs->hdfsDelete(fs, hdfs_path.c_str(), /*recursive=*/1) != 0)
TF_SetStatusFromIOError(status, errno, path);
else
TF_SetStatus(status, TF_OK, "");
}
void RenameFile(const TF_Filesystem* filesystem, const char* src,
const char* dst, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, src, status);
if (TF_GetCode(status) != TF_OK) return;
std::string scheme, namenode, hdfs_path_src, hdfs_path_dst;
ParseHadoopPath(src, &scheme, &namenode, &hdfs_path_src);
ParseHadoopPath(dst, &scheme, &namenode, &hdfs_path_dst);
if (libhdfs->hdfsExists(fs, hdfs_path_dst.c_str()) == 0 &&
libhdfs->hdfsDelete(fs, hdfs_path_dst.c_str(), /*recursive=*/0) != 0)
return TF_SetStatusFromIOError(status, errno, dst);
if (libhdfs->hdfsRename(fs, hdfs_path_src.c_str(), hdfs_path_dst.c_str()) !=
0)
TF_SetStatusFromIOError(status, errno, src);
else
TF_SetStatus(status, TF_OK, "");
}
int GetChildren(const TF_Filesystem* filesystem, const char* path,
char*** entries, TF_Status* status) {
auto libhdfs = static_cast<LibHDFS*>(filesystem->plugin_filesystem);
auto fs = Connect(libhdfs, path, status);
if (TF_GetCode(status) != TF_OK) return -1;
std::string scheme, namenode, hdfs_path;
ParseHadoopPath(path, &scheme, &namenode, &hdfs_path);
// hdfsListDirectory returns nullptr if the directory is empty. Do a separate
// check to verify the directory exists first.
TF_FileStatistics stat;
Stat(filesystem, path, &stat, status);
if (TF_GetCode(status) != TF_OK) return -1;
int num_entries = 0;
auto info = libhdfs->hdfsListDirectory(fs, hdfs_path.c_str(), &num_entries);
if (info == nullptr) {
if (stat.is_directory) {
// Assume it's an empty directory.
TF_SetStatus(status, TF_OK, "");
return 0;
}
TF_SetStatusFromIOError(status, errno, path);
return -1;
}
*entries = static_cast<char**>(
plugin_memory_allocate(num_entries * sizeof((*entries)[0])));
auto BaseName = [](const std::string& name) {
return name.substr(name.find_last_of('/') + 1);
};
for (int i = 0; i < num_entries; i++) {
(*entries)[i] = strdup(BaseName(info[i].mName).c_str());
}
libhdfs->hdfsFreeFileInfo(info, num_entries);
TF_SetStatus(status, TF_OK, "");
return num_entries;
}
// TODO(vnvo2409): Implement later
} // namespace tf_hadoop_filesystem

View File

@ -18,6 +18,7 @@ cc_library(
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/c/eager:gradients",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/core/lib/llvm_rtti",
],
)

View File

@ -14,9 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
namespace tensorflow {
namespace gradients {
@ -26,9 +31,9 @@ class AddGradientFunction : public GradientFunction {
public:
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
std::vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*>* grad_outputs) override {
grad_outputs->resize(2);
std::vector<AbstractTensorHandle*> identity_outputs(1);
vector<AbstractTensorHandle*> identity_outputs(1);
// TODO(b/145674566): Handle name unification in tracing code.
// TODO(b/161805092): Support broadcasting.
TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
@ -44,10 +49,38 @@ class AddGradientFunction : public GradientFunction {
~AddGradientFunction() override {}
};
class ExpGradientFunction : public GradientFunction {
public:
explicit ExpGradientFunction(AbstractTensorHandle* exp) : exp_(exp) {
exp->Ref();
}
Status Compute(Context* ctx,
absl::Span<AbstractTensorHandle* const> grad_inputs,
vector<AbstractTensorHandle*>* grad_outputs) override {
vector<AbstractTensorHandle*> conj_outputs(1);
TF_RETURN_IF_ERROR(
Conj(ctx->ctx, {exp_.get()}, absl::MakeSpan(conj_outputs), "ExpConj"));
AbstractTensorHandlePtr conj_output_releaser(conj_outputs[0]);
grad_outputs->resize(1);
TF_RETURN_IF_ERROR(Mul(ctx->ctx, {conj_outputs[0], grad_inputs[0]},
absl::MakeSpan(*grad_outputs), "ExpGradMul"));
return Status::OK();
}
~ExpGradientFunction() override {}
private:
AbstractTensorHandlePtr exp_;
};
} // namespace
GradientFunction* AddRegisterer(const ForwardOperation& op) {
return new AddGradientFunction;
}
GradientFunction* ExpRegisterer(const ForwardOperation& op) {
return new ExpGradientFunction(op.outputs[0]);
}
} // namespace gradients
} // namespace tensorflow

View File

@ -20,6 +20,7 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
GradientFunction* AddRegisterer(const ForwardOperation& op);
GradientFunction* ExpRegisterer(const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
@ -22,3 +23,26 @@ cc_library(
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "math_ops",
srcs = [
"math_ops.cc",
],
hdrs = [
"math_ops.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":array_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {

View File

@ -15,9 +15,9 @@ limitations under the License.
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_ARRAY_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_operation.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {

View File

@ -0,0 +1,55 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace ops {
using tensorflow::tracing::TracingOperation;
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr mul_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
if (isa<TracingOperation>(mul_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<TracingOperation>(mul_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
int num_retvals = 1;
return mul_op->Execute(outputs, &num_retvals);
}
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
auto dtype = inputs[0]->DataType();
if (DataTypeIsFloating(BaseType(dtype)) ||
DataTypeIsInteger(BaseType(dtype))) {
TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
} else {
return errors::Unimplemented("Conj does not support complex types yet.");
}
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -0,0 +1,31 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#define TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_OPS_MATH_OPS_H_

View File

@ -216,6 +216,23 @@ tf_cc_test(
],
)
tf_cc_test(
name = "signature_flattening_test",
srcs = [
"signature_flattening_test.cc",
],
deps = [
":saved_model_utils",
"//tensorflow/c/experimental/saved_model/core:tf_concrete_function_test_protos",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/common_runtime/eager:core",
],
)
tf_cc_test(
name = "tf_concrete_function_loading_test",
srcs = [

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
@ -36,52 +37,8 @@ namespace tensorflow {
namespace internal {
namespace {
// This returns the size of `tf.nest.flatten(value)`, on values that are
// used in tf.function's input_signatures.
int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (value.kind_case()) {
case StructuredValue::kDictValue: {
const DictValue& dict = value.dict_value();
int size = 0;
for (const auto& field : dict.fields()) {
size += FlattenedSize(field.second, status);
}
return size;
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = value.tuple_value();
int size = 0;
for (const StructuredValue& value : tuple.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kListValue: {
const ListValue& list = value.list_value();
int size = 0;
for (const StructuredValue& value : list.values()) {
size += FlattenedSize(value, status);
}
return size;
}
case StructuredValue::kTensorSpecValue: {
return 1;
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return 0;
}
default: {
status->Update(errors::Internal("Unhandled structured value kind ",
value.kind_case()));
return 0;
}
}
}
using StructuredValueDictEntry =
protobuf::MapPair<std::string, StructuredValue>;
// Perform some basic sanity checks on SavedConcreteFunction's input and
// output signatures with respect to the corresponding FunctionDef's input
@ -111,34 +68,34 @@ Status ValidateSavedFunctionCompatibleWithFunctionDef(
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979
const std::string& name = function_def->signature().name();
const StructuredValue& input_signature =
saved_concrete_function.canonicalized_input_signature();
Status status;
int input_signature_size = FlattenedSize(input_signature, &status);
TF_RETURN_IF_ERROR(status);
if (input_signature_size + saved_concrete_function.bound_inputs_size() !=
std::vector<const TensorSpecProto*> input_specs;
TF_RETURN_IF_ERROR(FlattenSignature(input_signature, &input_specs));
if (input_specs.size() + saved_concrete_function.bound_inputs_size() !=
function_def->signature().input_arg_size()) {
return errors::FailedPrecondition(
"FunctionDef ", name, " has ",
function_def->signature().input_arg_size(),
" inputs, but the SavedConcreteFunction has ", input_signature_size,
" inputs, but the SavedConcreteFunction has ", input_specs.size(),
" flattened user inputs and ",
saved_concrete_function.bound_inputs_size(), " captured inputs.");
}
const StructuredValue& output_signature =
saved_concrete_function.output_signature();
int output_signature_size = FlattenedSize(output_signature, &status);
TF_RETURN_IF_ERROR(status);
if (output_signature_size != function_def->signature().output_arg_size()) {
std::vector<const TensorSpecProto*> output_specs;
TF_RETURN_IF_ERROR(FlattenSignature(output_signature, &output_specs));
if (output_specs.size() != function_def->signature().output_arg_size()) {
return errors::FailedPrecondition(
"FunctionDef ", name, " has ",
function_def->signature().output_arg_size(),
" outputs, but the SavedConcreteFunction has ", output_signature_size,
" outputs, but the SavedConcreteFunction has ", output_specs.size(),
" flattened outputs.");
}
return status;
return Status();
}
} // namespace
@ -197,6 +154,62 @@ Status LoadTFConcreteFunction(
out);
}
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs) {
// This follows the logic from
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775
switch (signature.kind_case()) {
case StructuredValue::kDictValue: {
// Dictionaries must be sorted in order of keys
const DictValue& dict = signature.dict_value();
std::vector<const StructuredValueDictEntry*> entries;
entries.reserve(dict.fields_size());
for (const auto& field : dict.fields()) {
entries.push_back(&field);
}
std::sort(entries.begin(), entries.end(),
[](const StructuredValueDictEntry* x,
const StructuredValueDictEntry* y) {
return x->first < y->first;
});
for (const auto& entry : entries) {
TF_RETURN_IF_ERROR(FlattenSignature(entry->second, flattened_specs));
}
return Status();
}
case StructuredValue::kTupleValue: {
const TupleValue& tuple = signature.tuple_value();
for (const StructuredValue& value : tuple.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kListValue: {
const ListValue& list = signature.list_value();
for (const StructuredValue& value : list.values()) {
TF_RETURN_IF_ERROR(FlattenSignature(value, flattened_specs));
}
return Status();
}
case StructuredValue::kTensorSpecValue: {
flattened_specs->push_back(&signature.tensor_spec_value());
return Status();
}
case StructuredValue::kNoneValue: {
// Base case: do nothing.
// This arises, for example, as the top-level object of an output
// signature when there are no return values.
return Status();
}
default: {
return errors::Internal("Unhandled structured value kind ",
signature.kind_case());
}
}
}
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph) {
const auto& nodes = object_graph.nodes();

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace internal {
@ -59,10 +60,17 @@ Status LoadTFConcreteFunction(
captured_objects,
ImmediateExecutionContext* ctx, std::unique_ptr<TFConcreteFunction>* out);
// Find the SavedObject in `object_graph` at location `path`. `path` must be a
// dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph` outlives
// the returned pointer.
// Flattens `signature` into a vector of TensorSpecProto pointers back into
// `signature`. `signature` must outlive flattened_specs. `signature` must also
// be the input or output signature of a SavedConcreteFunction (i.e. "nested
// structures of tensorspecs").
Status FlattenSignature(const StructuredValue& signature,
std::vector<const TensorSpecProto*>* flattened_specs);
// Find the SavedObject in `object_graph` at location `path`. `path` must be
// a dot-delimited string of object names relative to the root object. If no
// object is found, returns nullptr. Callers must ensure `object_graph`
// outlives the returned pointer.
const SavedObject* FindNodeAtPath(StringPiece path,
const SavedObjectGraph& object_graph);

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <vector>
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/struct.pb.h"
namespace tensorflow {
namespace {
// Validates names, shapes, and dtypes of two tensorspecprotos are equivalent.
bool TensorSpecsAreEqual(const TensorSpecProto& spec,
const std::string& expected_name,
const PartialTensorShape& expected_shape,
DataType expected_dtype) {
return spec.name() == expected_name &&
PartialTensorShape(spec.shape()).IsIdenticalTo(expected_shape) &&
spec.dtype() == expected_dtype;
}
// This tests the common case for a tf.function w/o inputs. This ends up
// being serialized as a tuple of an empty tuple + empty dictionary
// (corresponding to the args, kwargs) of the function.
TEST(SignatureFlatteningTest, ZeroArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
// This tests the common case for a tf.function w/o outputs. This ends up
// being serialized as a "NoneValue".
TEST(SignatureFlatteningTest, ZeroRetOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ZeroReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 0);
}
TEST(SignatureFlatteningTest, SingleArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1, 10},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, SingleReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::SingleReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 1);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
}
TEST(SignatureFlatteningTest, ThreeArgInputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeArgInputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "x",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "y",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "z",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
// This test has an exotic outputsignature of tuple of a
// dictionary<string,tensor>, tensor
TEST(SignatureFlatteningTest, ThreeReturnOutputSignature) {
std::vector<const TensorSpecProto*> flattened;
StructuredValue value = testing::ThreeReturnOutputSignature();
TF_EXPECT_OK(internal::FlattenSignature(value, &flattened));
EXPECT_EQ(flattened.size(), 3);
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[0],
/* expected_name = */ "0/a",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[0]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[1],
/* expected_name = */ "0/b",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[1]->DebugString();
EXPECT_TRUE(TensorSpecsAreEqual(*flattened[2],
/* expected_name = */ "1",
/* expected_shape = */ {1},
/* expected_dtype = */ DT_FLOAT))
<< "Expected " << flattened[2]->DebugString();
}
} // namespace
} // namespace tensorflow

View File

@ -47,6 +47,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/stringpiece.h"
@ -241,8 +242,11 @@ Status RestoreCheckpoint(SavedModelV2Bundle* bundle,
// TODO(bmzhao): This requires using the newly added Save/Restore
// functions from
// https://github.com/tensorflow/tensorflow/commit/df6b21c13c82b5d0981642cfe18f10e60f78ea5c
return errors::Unimplemented(
"Restoring non-variable objects has not been implemented yet. ");
LOG(WARNING) << "Restoring non-variable objects has not been "
"implemented yet. (Kind="
<< bundle->saved_object_graph().nodes(node).kind_case()
<< ")";
return Status::OK();
}
Variable* variable =

View File

@ -38,8 +38,6 @@ cc_library(
":concrete_function_type",
":function_metadata",
":function_metadata_type",
":tensorhandle_list",
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c:tf_status_internal",
"//tensorflow/c/eager:abstract_tensor_handle",
@ -167,38 +165,6 @@ cc_library(
],
)
cc_library(
name = "tensorhandle_list",
srcs = [
"tensorhandle_list.cc",
],
hdrs = [
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list.h",
],
copts = tf_copts(),
visibility = [
"//tensorflow/c/experimental/saved_model/public:__pkg__",
],
deps = [
":tensorhandle_list_type",
"//tensorflow/c:c_api_macros",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
"//tensorflow/c/eager:tfe_tensorhandle_internal",
],
)
cc_library(
name = "tensorhandle_list_type",
hdrs = [
"tensorhandle_list_type.h",
],
deps = [
"//tensorflow/c:conversion_macros",
"//tensorflow/c/eager:immediate_execution_tensor_handle",
],
)
tf_cc_test(
name = "saved_model_api_test",
size = "small",
@ -216,7 +182,6 @@ tf_cc_test(
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/saved_model/public:concrete_function",
"//tensorflow/c/experimental/saved_model/public:saved_model_api",
"//tensorflow/c/experimental/saved_model/public:tensorhandle_list",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/core/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h"
#include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h"
#include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/platform/status.h"

View File

@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/experimental/saved_model/public/concrete_function.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/io/path.h"

View File

@ -24,7 +24,6 @@ exports_files(
"concrete_function_list.h",
"function_metadata.h",
"saved_model_api.h",
"tensorhandle_list.h",
],
visibility = ["//tensorflow/c/experimental/saved_model/internal:__pkg__"],
)
@ -40,7 +39,6 @@ cc_library(
":concrete_function_list",
":function_metadata",
":saved_model_api",
":tensorhandle_list",
],
)
@ -63,8 +61,3 @@ alias(
name = "saved_model_api",
actual = "//tensorflow/c/experimental/saved_model/internal:saved_model_api",
)
alias(
name = "tensorhandle_list",
actual = "//tensorflow/c/experimental/saved_model/internal:tensorhandle_list",
)

View File

@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/c/experimental/saved_model/public/concrete_function_list.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/saved_model_api.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
// IWYU pragma: end_exports
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_C_SAVED_MODEL_API_H_

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/experimental/saved_model/public/function_metadata.h"
#include "tensorflow/c/experimental/saved_model/public/tensorhandle_list.h"
#ifdef __cplusplus
extern "C" {

View File

@ -1,43 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_
#include <stddef.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
// An opaque type that is acts like a list of TF_ConcreteFunction pointers.
typedef struct TF_TensorHandleList TF_TensorHandleList;
// Returns the size of `list`.
TF_CAPI_EXPORT extern size_t TF_TensorHandleListSize(
const TF_TensorHandleList* list);
// Returns the `i`th TFE_TensorHandle in the list.
TF_CAPI_EXPORT extern TFE_TensorHandle* TF_TensorHandleListGet(
const TF_TensorHandleList* list, int i);
#ifdef __cplusplus
} // end extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_PUBLIC_TENSORHANDLE_LIST_H_

View File

@ -261,7 +261,6 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
size_t len, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
@ -279,4 +278,43 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index,
return nullptr;
}
return tf_tensor;
}
}
TF_Tensor* TF_AllocateTemp(TF_OpKernelContext* context, TF_DataType dtype,
int64_t* dims, int num_dims,
TF_AllocatorAttributes* attributes,
TF_Status* status) {
auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context);
TF_SetStatus(status, TF_OK, "");
static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
"64-bit int types should match in size");
tensorflow::gtl::ArraySlice<tensorflow::int64> dimarray(
reinterpret_cast<tensorflow::int64*>(dims), num_dims);
if (attributes && !attributes->struct_size) {
TF_SetStatus(
status, TF_INVALID_ARGUMENT,
"TF_AllocatorAttributes struct "
"size member must be set to TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE");
return nullptr;
}
tensorflow::AllocatorAttributes allocator_attr;
if (attributes && attributes->on_host) {
allocator_attr.set_on_host(true);
}
tensorflow::Status s;
tensorflow::Tensor tensor;
s = cc_ctx->allocate_temp(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimarray), &tensor,
allocator_attr);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
TF_Tensor* tf_tensor;
tf_tensor = TF_TensorFromTensor(tensor, &s);
if (!s.ok()) {
::tensorflow::Set_TF_Status_from_Status(status, s);
return nullptr;
}
return tf_tensor;
}

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
// Macro to control visibility of exported symbols in the shared library (.so,
// .dylib, .dll).
@ -199,6 +200,15 @@ TF_CAPI_EXPORT TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context,
int64_t* dims, int num_dims,
size_t len, TF_Status* status);
// Allocates a temporary Tensor of the specified type and shape. The
// Tensor must not be used after kernel construction is
// complete.
//
// num_dims must equal the size of array dims
TF_CAPI_EXPORT extern TF_Tensor* TF_AllocateTemp(
TF_OpKernelContext* context, TF_DataType dtype, int64_t* dims, int num_dims,
TF_AllocatorAttributes* alloc_attrs, TF_Status* status);
#ifdef __cplusplus
} /* end extern "C" */
#endif

View File

@ -38,6 +38,20 @@ tf_kernel_library(
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "histogram_summary_op",
prefix = "histogram_summary_op",
deps = [
"//tensorflow/c:kernels",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_tensor",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//third_party/eigen3",
],
)
tf_kernel_library(
name = "merge_summary_op",
@ -72,6 +86,15 @@ tf_gen_op_libs(
],
)
tf_gen_op_libs(
op_lib_names = ["histogram_summary"],
deps = [
"//tensorflow/c:ops",
"//tensorflow/c:tf_status",
"//tensorflow/core:lib",
],
)
tf_gen_op_libs(
op_lib_names = ["merge_summary"],
deps = [
@ -144,6 +167,7 @@ filegroup(
name = "android_all_op_kernels",
srcs = [
"bitcast_op.cc",
"histogram_summary_op.cc",
"merge_summary_op.cc",
"summary_op.cc",
"tensor_shape_utils.cc",
@ -156,6 +180,7 @@ filegroup(
name = "android_all_ops",
srcs = [
"ops/bitcast.cc",
"ops/histogram_summary.cc",
"ops/merge_summary.cc",
"ops/summary.cc",
],

View File

@ -0,0 +1,163 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <sstream>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/c/kernels.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
namespace {
// Operators used to create a std::unique_ptr for TF_Tensor and TF_Status.
struct TFTensorDeleter {
void operator()(TF_Tensor* tf_tensor) const { TF_DeleteTensor(tf_tensor); }
};
struct TFStatusDeleter {
void operator()(TF_Status* tf_status) const { TF_DeleteStatus(tf_status); }
};
// Struct that wraps TF_Tensor and TF_Status to delete once out of scope.
using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, TFTensorDeleter>;
using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, TFStatusDeleter>;
// Used to pass the operation node name from kernel construction to
// kernel computation.
struct HistogramSummaryOp {
std::string op_node_name;
};
void* HistogramSummaryOp_Create(TF_OpKernelConstruction* ctx) {
HistogramSummaryOp* kernel = new HistogramSummaryOp;
TF_StringView string_view_name = TF_OpKernelConstruction_GetName(ctx);
kernel->op_node_name =
std::string(string_view_name.data, string_view_name.len);
return kernel;
}
void HistogramSummaryOp_Delete(void* kernel) {
delete static_cast<HistogramSummaryOp*>(kernel);
}
template <typename T>
void HistogramSummaryOp_Compute(void* kernel, TF_OpKernelContext* ctx) {
HistogramSummaryOp* k = static_cast<HistogramSummaryOp*>(kernel);
TF_Tensor* tags;
TF_Tensor* values;
Safe_TF_StatusPtr status(TF_NewStatus());
TF_GetInput(ctx, 0, &tags, status.get());
Safe_TF_TensorPtr safe_tags_ptr(tags);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
TF_GetInput(ctx, 1, &values, status.get());
Safe_TF_TensorPtr safe_values_ptr(values);
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
if (TF_NumDims(safe_tags_ptr.get()) != 0) {
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, "tags must be scalar");
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
// Cast values to array to access tensor elements by index
auto values_array = static_cast<T*>(TF_TensorData(safe_values_ptr.get()));
tensorflow::histogram::Histogram histo;
for (int64_t i = 0; i < TF_TensorElementCount(safe_values_ptr.get()); ++i) {
const double double_val = static_cast<double>(values_array[i]);
if (Eigen::numext::isnan(double_val)) {
std::ostringstream err;
err << "Nan in summary histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
} else if (Eigen::numext::isinf(double_val)) {
std::ostringstream err;
err << "Infinity in Histogram for: " << k->op_node_name;
TF_SetStatus(status.get(), TF_INVALID_ARGUMENT, err.str().c_str());
return;
}
histo.Add(double_val);
}
tensorflow::Summary s;
tensorflow::Summary::Value* v = s.add_value();
const tensorflow::tstring& tag =
*(static_cast<tensorflow::tstring*>(TF_TensorData(safe_tags_ptr.get())));
v->set_tag(tag.data(), tag.size());
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
Safe_TF_TensorPtr summary_tensor(TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_ExpectedOutputDataType(ctx, 0),
/*dims=*/nullptr, /*num_dims=*/0,
/*len=*/sizeof(tensorflow::tstring), status.get()));
if (TF_GetCode(status.get()) != TF_OK) {
TF_OpKernelContext_Failure(ctx, status.get());
return;
}
tensorflow::tstring* output_tstring = reinterpret_cast<tensorflow::tstring*>(
TF_TensorData(summary_tensor.get()));
CHECK(SerializeToTString(s, output_tstring));
}
template <typename T>
void RegisterHistogramSummaryOpKernel() {
TF_Status* status = TF_NewStatus();
{
auto* builder = TF_NewKernelBuilder(
"HistogramSummary", tensorflow::DEVICE_CPU, &HistogramSummaryOp_Create,
&HistogramSummaryOp_Compute<T>, &HistogramSummaryOp_Delete);
TF_KernelBuilder_TypeConstraint(
builder, "T",
static_cast<TF_DataType>(tensorflow::DataTypeToEnum<T>::v()), status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << "Error while adding type constraint";
TF_RegisterKernelBuilder("HistogramSummary", builder, status);
CHECK_EQ(TF_OK, TF_GetCode(status))
<< "Error while registering Histogram Summmary kernel";
}
TF_DeleteStatus(status);
}
// A dummy static variable initialized by a lambda whose side-effect is to
// register the Histogram Summary kernel.
TF_ATTRIBUTE_UNUSED static bool IsHistogramSummaryOpKernelRegistered = []() {
if (SHOULD_REGISTER_OP_KERNEL("HistogramSummary")) {
RegisterHistogramSummaryOpKernel<tensorflow::int64>();
RegisterHistogramSummaryOpKernel<tensorflow::uint64>();
RegisterHistogramSummaryOpKernel<tensorflow::int32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint32>();
RegisterHistogramSummaryOpKernel<tensorflow::uint16>();
RegisterHistogramSummaryOpKernel<tensorflow::int16>();
RegisterHistogramSummaryOpKernel<tensorflow::int8>();
RegisterHistogramSummaryOpKernel<tensorflow::uint8>();
RegisterHistogramSummaryOpKernel<Eigen::half>();
RegisterHistogramSummaryOpKernel<tensorflow::bfloat16>();
RegisterHistogramSummaryOpKernel<float>();
RegisterHistogramSummaryOpKernel<double>();
}
return true;
}();
} // namespace

View File

@ -0,0 +1,50 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/ops.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/selective_registration.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
static void histogram_summary_shape_inference_fn(TF_ShapeInferenceContext* ctx,
TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
TF_ShapeHandle* result = TF_ShapeInferenceContextScalar(ctx);
TF_ShapeInferenceContextSetOutput(ctx, 0, result, status);
TF_DeleteShapeHandle(result);
}
void Register_HistogramSummaryOp() {
TF_Status* status = TF_NewStatus();
TF_OpDefinitionBuilder* op_builder =
TF_NewOpDefinitionBuilder("HistogramSummary");
TF_OpDefinitionBuilderAddInput(op_builder, "tag: string");
TF_OpDefinitionBuilderAddInput(op_builder, "values: T");
TF_OpDefinitionBuilderAddOutput(op_builder, "summary: string");
TF_OpDefinitionBuilderAddAttr(op_builder, "T: realnumbertype = DT_FLOAT");
TF_OpDefinitionBuilderSetShapeInferenceFunction(
op_builder, &histogram_summary_shape_inference_fn);
TF_RegisterOpDefinition(op_builder, status);
CHECK_EQ(TF_GetCode(status), TF_OK)
<< "HistogramSummary op registration failed: " << TF_Message(status);
TF_DeleteStatus(status);
}
TF_ATTRIBUTE_UNUSED static bool HistogramSummaryOpRegistered = []() {
if (SHOULD_REGISTER_OP("HistogramSummary")) {
Register_HistogramSummaryOp();
}
return true;
}();

View File

@ -368,6 +368,16 @@ class DeviceKernelOpTest : public OpsTestBase {
#endif
};
// Validates that the tensor has shape and type corresponding to
// dims and dtype.
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype);
// Copies data of length tensor_size_bytes from values to tensor.
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx);
REGISTER_OP("AllocateOutputOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
@ -379,22 +389,11 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSizeOne) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(1, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set output to 3
float* data = reinterpret_cast<float*>(TF_TensorData(output));
float value = 3.0f;
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, &value,
tensor_size_bytes);
#else
*data = value;
#endif
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -417,12 +416,8 @@ TEST_F(DeviceKernelOpTest, TestAllocateEmptyOutput) {
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*len=*/0, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(1, TF_NumDims(output));
EXPECT_EQ(0, TF_Dim(output, 0));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -442,27 +437,16 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
TF_Status* s = TF_NewStatus();
// Allocate 2x3 output
int64_t dim[2] = {2, 3};
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT) * 6;
TF_Tensor* output = TF_AllocateOutput(
/*context=*/ctx, /*index=*/0, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*len=*/tensor_size_bytes, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
EXPECT_EQ(TF_FLOAT, TF_TensorType(output));
EXPECT_EQ(2, TF_NumDims(output));
EXPECT_EQ(2, TF_Dim(output, 0));
EXPECT_EQ(3, TF_Dim(output, 1));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set output to [1 2 3 4 5 6]
void* data = TF_TensorData(output);
float value[6] = {1, 2, 3, 4, 5, 6};
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, value,
tensor_size_bytes);
#else
memcpy(data, value, tensor_size_bytes);
#endif
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
@ -474,4 +458,132 @@ TEST_F(DeviceKernelOpTest, TestAllocateOutputSize2x3) {
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp1").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSizeOne) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
// Allocate scalar TF_Tensor
TF_Status* s = TF_NewStatus();
int64_t dim = 1;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
size_t tensor_size_bytes = TF_DataTypeSize(TF_FLOAT);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
// Set TF_Tensor value to 3
float values[1] = {3.0f};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp1", "AllocateTemp1", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [1] values: 3>",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp0").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempEmpty) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
// Allocate empty TF_Tensor
int64_t dim = 0;
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/&dim,
/*num_dims=*/1, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, &dim, 1, TF_FLOAT);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp0", "AllocateTemp0", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [0] values: >",
output->DebugString(100));
}
REGISTER_OP("AllocateTempOp2x3").Output("output1: float");
TEST_F(DeviceKernelOpTest, TestAllocateTempSize2x3) {
auto my_compute_func = [](void* kernel, TF_OpKernelContext* ctx) {
TF_Status* s = TF_NewStatus();
size_t tensor_size_bytes = 6 * TF_DataTypeSize(TF_FLOAT);
// Allocate 2x3 TF_Tensor
int64_t dim[2] = {2, 3};
TF_AllocatorAttributes alloc_attrs;
alloc_attrs.struct_size = TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE;
#if GOOGLE_CUDA
alloc_attrs.on_host = 0;
#else
alloc_attrs.on_host = 1;
#endif
TF_Tensor* output = TF_AllocateTemp(
/*context=*/ctx, /*dtype=*/TF_FLOAT, /*dims=*/dim,
/*num_dims=*/2, /*allocator_attributes*/ &alloc_attrs, s);
EXPECT_EQ(TF_OK, TF_GetCode(s));
validate_tensor(output, dim, 2, TF_FLOAT);
// Set TF_Tensor values to [1 2 3 4 5 6]
float values[6] = {1, 2, 3, 4, 5, 6};
set_tensor_data<float>(output, values, tensor_size_bytes, ctx);
TF_SetOutput(ctx, 0, output, s);
TF_DeleteStatus(s);
TF_DeleteTensor(output);
};
SetupOp("AllocateTempOp2x3", "AllocateTempOp2x3", my_compute_func);
TF_ASSERT_OK(RunOpKernel());
Tensor* output = GetOutput(0);
EXPECT_EQ("Tensor<type: float shape: [2,3] values: [1 2 3][4 5 6]>",
output->DebugString(100));
}
void validate_tensor(TF_Tensor* tensor, int64_t* dims, int64_t num_dims,
TF_DataType dtype) {
EXPECT_EQ(TF_FLOAT, TF_TensorType(tensor));
EXPECT_EQ(num_dims, TF_NumDims(tensor));
for (int i = 0; i < num_dims; ++i) {
EXPECT_EQ(dims[i], TF_Dim(tensor, i));
}
}
template <typename T>
void set_tensor_data(TF_Tensor* tensor, T* values, size_t tensor_size_bytes,
TF_OpKernelContext* ctx) {
T* data = reinterpret_cast<T*>(TF_TensorData(tensor));
#if GOOGLE_CUDA
OpKernelContext* cc_ctx = reinterpret_cast<OpKernelContext*>(ctx);
cc_ctx->eigen_gpu_device().memcpyHostToDevice(data, values,
tensor_size_bytes);
#else
memcpy(data, values, tensor_size_bytes);
#endif
}
} // namespace tensorflow

View File

@ -288,7 +288,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, Status* status) {
if (!tensor.CopyFrom(src, src.shape())) {
return nullptr;
}
return new TF_Tensor{new tensorflow::TensorInterface(tensor)};
return new TF_Tensor{new tensorflow::TensorInterface(std::move(tensor))};
}
Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) {

View File

@ -19,6 +19,7 @@ limitations under the License.
#include <stdbool.h>
#include <stdint.h>
#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
@ -45,6 +46,16 @@ limitations under the License.
extern "C" {
#endif
// Allocator Attributes used for tensor allocation.
typedef struct TF_AllocatorAttributes {
size_t struct_size;
// Set boolean to 1 for CPU allocation, else 0.
TF_Bool on_host;
} TF_AllocatorAttributes;
#define TF_ALLOCATOR_ATTRIBUTES_STRUCT_SIZE \
TF_OFFSET_OF_END(TF_AllocatorAttributes, on_host)
// --------------------------------------------------------------------------
// TF_Tensor holds a multi-dimensional array of elements of a single data type.
// For all types other than TF_STRING, the data buffer stores elements

View File

@ -558,6 +558,7 @@ tf_gen_op_wrappers_cc(
"io_ops",
"linalg_ops",
"list_ops",
"map_ops",
"logging_ops",
"lookup_ops",
"manip_ops",

View File

@ -128,22 +128,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_interpreter_device",
srcs = ["xla_interpreter_device.cc"],
visibility = [":friends"],
deps = [
":jit_compilation_passes",
":xla_device",
"//tensorflow/compiler/jit/kernels:xla_ops",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
"//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep
"@com_google_absl//absl/memory",
],
alwayslink = 1,
)
cc_library(
name = "xla_tensor",
srcs = ["xla_tensor.cc"],
@ -211,6 +195,7 @@ XLA_DEVICE_DEPS = [
"//tensorflow/core/kernels/data:optional_ops",
"//tensorflow/core/kernels/data:prefetch_dataset_op",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/stream_executor:tf_allocator_adapter",
"//tensorflow/stream_executor/platform",
]
@ -221,16 +206,19 @@ cc_library(
"xla_device.cc",
"xla_device_context.cc",
"xla_device_ops.cc",
"xla_ops_on_regular_devices.cc",
"xla_platform_info.cc",
],
hdrs = [
"xla_compile_on_demand_op.h",
"xla_device.h",
"xla_device_context.h",
"xla_device_ops.h",
"xla_platform_info.h",
],
# Public visibility is needed for external TF/XLA backends.
visibility = ["//visibility:public"],
deps = XLA_DEVICE_DEPS,
deps = XLA_DEVICE_DEPS + [":xla_compilation_cache"],
)
cc_library(
@ -394,20 +382,23 @@ cc_library(
alwayslink = 1,
)
# Linked by tensorflow core, without registration of jit compilation passes
# which is not necessary to create and run a XlaLocalLaunchBase kernel.
# Linking jit compilation passes could cause programs stuck right now (b/140069592).
cc_library(
name = "xla_kernel_creator_util",
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator_util.cc",
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
visibility = [
":internal",
"//learning/brain/contrib/tpu_modeling/exp/tpu_inference_converter:__pkg__",
"//tensorflow/core/common_runtime/eager:__pkg__",
],
hdrs = ["xla_kernel_creator_util.h"],
visibility = ["//tensorflow/core/common_runtime/eager:__pkg__"],
deps = [
":common",
":compilability_check_util",
":compilation_passes",
":flags",
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
@ -422,25 +413,6 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "xla_kernel_creator",
srcs = [
"xla_kernel_creator.cc",
"xla_kernel_creator.h",
],
deps = [
":compilability_check_util",
":flags",
":jit_compilation_passes",
":xla_kernel_creator_util",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
],
alwayslink = 1,
)
tf_cc_test(
name = "xla_kernel_creator_test",
srcs = [

View File

@ -159,7 +159,7 @@ void AllocateAndParseFlags() {
device_flags = new XlaDeviceFlags;
device_flags->tf_xla_compile_on_demand = false;
device_flags->tf_xla_enable_xla_devices = true;
device_flags->tf_xla_enable_xla_devices = false;
ops_flags = new XlaOpsCommonFlags;
ops_flags->tf_xla_always_defer_compilation = false;
@ -268,10 +268,4 @@ void AppendMarkForCompilationPassFlags(std::vector<Flag>* flag_list) {
AppendMarkForCompilationPassFlagsInternal(flag_list);
}
static bool xla_is_enabled = false;
void SetXlaIsEnabled() { xla_is_enabled = true; }
bool IsXlaEnabled() { return xla_is_enabled; }
} // namespace tensorflow

View File

@ -162,14 +162,6 @@ MlirCommonFlags* GetMlirCommonFlags();
void AppendMarkForCompilationPassFlags(
std::vector<tensorflow::Flag>* flag_list);
// Makes all future calls to `IsXlaEnabled()` return `true`.
//
// Should only be called when XLA is linked in.
void SetXlaIsEnabled();
// Returns whether XLA is enabled.
bool IsXlaEnabled();
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_FLAGS_H_

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_activity_listener.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -63,38 +64,6 @@ namespace tensorflow {
namespace {
XlaPlatformInfo PlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
// A closure describing how to run a compiled version of a TensorFlow function.
//
@ -178,31 +147,6 @@ class XlaExecutableClosureStore {
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
// Return allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
// platform_info.
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
} // namespace
XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
@ -214,65 +158,9 @@ XlaLocalLaunchBase::XlaLocalLaunchBase(OpKernelConstruction* ctx,
constants_(constants),
resources_(resources),
function_(function),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
has_ref_vars_(has_ref_vars) {}
static Status BuildCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
// operations). Such a cluster can fail compilation (in way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
static Status CompileToLocalExecutable(
OpKernelContext* ctx, const NameAttrList& function, bool has_ref_vars,
const XlaPlatformInfo& platform_info,
@ -292,7 +180,7 @@ static Status CompileToLocalExecutable(
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
return BuildCompilationCache(ctx, platform_info, cache);
return BuildXlaCompilationCache(ctx, platform_info, cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
@ -302,32 +190,14 @@ static Status CompileToLocalExecutable(
*client = static_cast<xla::LocalClient*>(cache->client());
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options;
options.client = *client;
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
XlaCompiler::Options options = GenerateCompilerOptions(
*cache, ctx, platform_info, has_ref_vars, &tf_allocator_adapter);
std::map<int, Tensor> constant_args;
for (int i : constants) {
constant_args.insert({i, ctx->input(i)});
}
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
// Optimization: where possible, have the computation return a naked array
@ -503,7 +373,7 @@ XlaCompileOp::XlaCompileOp(OpKernelConstruction* ctx)
constants_(ConstantsVector(ctx)),
resources_(ResourcesVector(ctx)),
function_(FunctionAttr(ctx)),
platform_info_(PlatformInfoFromContext(ctx)),
platform_info_(XlaPlatformInfoFromContext(ctx)),
must_compile_(MustCompileAttr(ctx)),
has_ref_vars_(HasRefVars(ctx)) {}
@ -591,7 +461,7 @@ void XlaCompileOp::Compute(OpKernelContext* ctx) {
}
XlaRunOp::XlaRunOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(PlatformInfoFromContext(ctx)) {}
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void XlaRunOp::Compute(OpKernelContext* ctx) {
VLOG(3) << "XlaRunOp " << def().name();

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -31,61 +32,6 @@ limitations under the License.
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run on.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// XlaLocalLaunchBase is almost the same as XlaLocalLaunchOp.
// The only difference is that it does not require arguments to follow

View File

@ -44,6 +44,11 @@ using ::tensorflow::testing::FindNodeByName;
namespace tensorflow {
namespace {
static bool Initialized = [] {
tensorflow::GetXlaDeviceFlags()->tf_xla_enable_xla_devices = true;
return true;
}();
REGISTER_OP("UncompilableNullary").Output("o: float");
REGISTER_OP("UncompilableUnary").Input("a: float").Output("o: float");

View File

@ -406,37 +406,6 @@ TEST(PartiallyDeclusterPassTest, DontDeclusterXlaDeviceOps) {
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, DontDeclusterNonTensorFlowOps) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output dynamic_slice_operand =
ops::Placeholder(s.WithOpName("dynamic_slice_operand"), DT_INT32,
ops::Placeholder::Attrs{});
Output dynamic_slice_begin = ops::Placeholder(
s.WithOpName("dynamic_slice_begin"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice_size = ops::Placeholder(
s.WithOpName("dynamic_slice_size"), DT_INT32, ops::Placeholder::Attrs{});
Output dynamic_slice =
ops::XlaDynamicSlice(s.WithOpName("dynamic_slice"), dynamic_slice_operand,
dynamic_slice_begin, dynamic_slice_size);
Output reshape_input = ops::Placeholder(s.WithOpName("reshape_input"),
DT_FLOAT, ops::Placeholder::Attrs{});
Output reshape =
ops::Reshape(s.WithOpName("reshape"), reshape_input, dynamic_slice);
AddToCluster({dynamic_slice.node(), reshape.node()}, "cluster_0");
std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
TF_ASSERT_OK(s.ToGraph(graph.get()));
Node* n = FindNodeByName(*graph, "dynamic_slice");
ASSERT_NE(n, nullptr);
TF_ASSERT_OK(PartiallyDecluster(&graph));
EXPECT_EQ(GetXlaClusterForNode(*n), "cluster_0");
}
TEST(PartiallyDeclusterPassTest, EliminatedUnusedNodes) {
const char* const kClusteredProducer0Name = "ClusteredProducer0";
const char* const kClusteredProducer1Name = "ClusteredProducer1";

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
@ -41,18 +42,21 @@ static std::vector<int> GetResourceVariableIndices(OpKernelContext* ctx) {
}
Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
const XlaDevice::Metadata& metadata,
XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args) {
xla::LocalClient* client = metadata.client();
xla::LocalClient* client = static_cast<xla::LocalClient*>(cache->client());
// Builds an XLA allocator for the device.
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator =
GetAllocator(&tf_allocator_adapter, ctx, platform_info_);
XlaComputationLaunchContext launch_context(
client, client->backend().memory_allocator(),
client->default_device_ordinal(),
/*allocate_xla_tensors=*/true,
/*use_multiple_streams=*/metadata.UseMultipleStreams());
client, allocator, client->default_device_ordinal(),
/*allocate_xla_tensors=*/platform_info_.xla_device_metadata() != nullptr,
platform_info_.xla_device_metadata()
? platform_info_.xla_device_metadata()->UseMultipleStreams()
: false);
std::map<int, const Tensor*> snapshot_ptrs;
for (auto& p : variable_args) {
@ -70,12 +74,11 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
VLOG(2) << "Executing computation: " << name();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_allocator(client->backend().memory_allocator());
run_options.set_allocator(allocator);
run_options.set_intra_op_thread_pool(&ctx->eigen_cpu_device());
run_options.set_rng_seed(GetXLARandomSeed());
@ -94,71 +97,54 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
return Status::OK();
}
Status XlaCompileOnDemandOp::MustArgumentBeConstant(
const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime, bool* result) {
*result = false;
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache, ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
// TODO(jmolloy): This could be expensive, so memoize.
std::vector<int> constant_input_indices;
TF_RETURN_IF_ERROR(GetCompileTimeConstInputs(
op_kernel, &constant_input_indices, flib_runtime));
*result = absl::c_binary_search(constant_input_indices, argument_idx);
return Status::OK();
}
&ctx->op_kernel(), &constant_input_indices, ctx->function_library()));
CHECK(absl::c_is_sorted(constant_input_indices));
// TODO(ycao): Remove the need to call ShouldArgumentBeConstant. Its benefit is
// not clear yet and it causes heavy constant analysis to run twice.
Status XlaCompileOnDemandOp::ShouldArgumentBeConstant(
const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime, bool* result) {
return MustArgumentBeConstant(op_kernel, argument_idx, flib_runtime, result);
}
Status XlaCompileOnDemandOp::Compile(
OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
const XlaCompiler::CompilationResult** result,
ResourceVarsSnapshot* variable_args, xla::LocalExecutable** executable) {
std::map<int, Tensor> constant_arguments;
for (int64 i = 0; i < ctx->num_inputs(); ++i) {
const Tensor& device_tensor = ctx->input(i);
if (const XlaTensor* xla_tensor = XlaTensor::FromTensor(&device_tensor)) {
if (xla_tensor->has_host_tensor()) {
bool should_arg_be_const;
TF_RETURN_IF_ERROR(ShouldArgumentBeConstant(&ctx->op_kernel(), i,
ctx->function_library(),
&should_arg_be_const));
if (should_arg_be_const) {
if (absl::c_binary_search(constant_input_indices, i)) {
constant_arguments[i] = xla_tensor->host_tensor();
}
}
}
if (constant_arguments.count(i) == 0) {
bool must_argument_be_const;
TF_RETURN_IF_ERROR(MustArgumentBeConstant(&ctx->op_kernel(), i,
ctx->function_library(),
&must_argument_be_const));
if (must_argument_be_const) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(
device_tensor.dtype(), device_tensor.shape(), &host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
if (!constant_arguments.count(i)) {
if (absl::c_binary_search(constant_input_indices, i)) {
if (ctx->input_memory_type(i) != HOST_MEMORY &&
ctx->op_device_context()) {
// Slow path; the argument is not available as a host constant so we
// must fetch it synchronously.
Tensor host_tensor;
AllocatorAttributes attrs;
attrs.set_on_host(true);
TF_RETURN_IF_ERROR(ctx->allocate_temp(device_tensor.dtype(),
device_tensor.shape(),
&host_tensor, attrs));
Status status = ctx->op_device_context()->CopyDeviceTensorToCPUSync(
&device_tensor, "ConstantArgument",
reinterpret_cast<Device*>(ctx->device()), &host_tensor);
if (!status.ok()) {
LOG(ERROR) << "Copying tensor of shape "
<< device_tensor.shape().DebugString() << " from "
<< ctx->device()->name() << "to CPU failed with "
<< status.ToString();
return status;
}
constant_arguments[i] = host_tensor;
} else {
constant_arguments[i] = device_tensor;
}
constant_arguments[i] = host_tensor;
}
}
}
@ -168,24 +154,16 @@ Status XlaCompileOnDemandOp::Compile(
ResourceMgr* rm = ctx->resource_manager();
CHECK(rm);
XlaCompilationCache* cache;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
*cache = new XlaCompilationCache(metadata.client(),
metadata.jit_device_type());
return Status::OK();
rm->default_container(), "xla_cache", cache,
[&](XlaCompilationCache** write_into_cache) {
return BuildXlaCompilationCache(ctx, platform_info_, write_into_cache);
}));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
XlaCompiler::Options options;
options.device_type = metadata.jit_device_type();
options.client = metadata.client();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.shape_representation_fn = metadata.shape_representation_fn();
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
XlaCompiler::Options options =
GenerateCompilerOptions(**cache, ctx, platform_info_,
/*has_ref_vars=*/true, &tf_allocator_adapter);
XlaCompiler::CompileOptions compile_options;
compile_options.is_entry_computation = true;
@ -206,19 +184,25 @@ Status XlaCompileOnDemandOp::Compile(
constant_arguments, variable_infos, ctx, &args));
}
return cache->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
return (*cache)->CompileSingleOp(options, args, ctx, compile_options, result,
executable);
}
void XlaCompileOnDemandOp::Compute(OpKernelContext* ctx) {
const XlaCompiler::CompilationResult* result;
xla::LocalExecutable* executable;
const XlaDevice::Metadata* metadata;
OP_REQUIRES_OK(ctx, XlaDevice::GetMetadata(ctx, &metadata));
ResourceVarsSnapshot variable_args;
XlaCompilationCache* cache;
OP_REQUIRES(ctx, ctx->function_library(),
errors::Internal("Function library missing"));
OP_REQUIRES_OK(ctx,
Compile(ctx, *metadata, &result, &variable_args, &executable));
OP_REQUIRES_OK(ctx, Run(ctx, *metadata, result, executable, variable_args));
Compile(ctx, &result, &cache, &variable_args, &executable));
// Hold the reference to the JIT during evaluation. (We could probably
// free it sooner because the ResourceMgr will retain a reference, but
// this is more obviously correct.)
core::ScopedUnref cache_ref(cache);
OP_REQUIRES_OK(ctx, Run(ctx, cache, result, executable, variable_args));
}
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/core/framework/function.h"
@ -35,25 +36,24 @@ namespace tensorflow {
// vanilla TensorFlow op as long as the bridge supports it.
class XlaCompileOnDemandOp : public OpKernel {
public:
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
explicit XlaCompileOnDemandOp(OpKernelConstruction* ctx)
: OpKernel(ctx), platform_info_(XlaPlatformInfoFromContext(ctx)) {}
void Compute(OpKernelContext* ctx) override;
private:
XlaCompiler::Argument CreateCompilerArgument(OpKernelContext* ctx, int64 i);
Status ShouldArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime,
bool* result);
Status MustArgumentBeConstant(const OpKernel* op_kernel, int64 argument_idx,
FunctionLibraryRuntime* flib_runtime,
bool* result);
Status Compile(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Compile(OpKernelContext* ctx,
const XlaCompiler::CompilationResult** result,
XlaCompilationCache** cache,
ResourceVarsSnapshot* variable_args,
xla::LocalExecutable** executable);
Status Run(OpKernelContext* ctx, const XlaDevice::Metadata& metadata,
Status Run(OpKernelContext* ctx, XlaCompilationCache* cache,
const XlaCompiler::CompilationResult* result,
xla::LocalExecutable* executable,
const ResourceVarsSnapshot& variable_args);
const XlaPlatformInfo platform_info_;
};
} // namespace tensorflow

View File

@ -61,6 +61,21 @@ limitations under the License.
namespace tensorflow {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
// Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
// XlaDeviceAllocator is created on demand and is associated with a
// XlaDevice. It outlives the device itself (for instance, the buffer
@ -116,20 +131,6 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
namespace {
// Default PaddedShapeFn implementation that simply returns the unpadded
// on-device shape. This is accurate for CPU and GPU devices that neither
// transpose nor pad tensors.
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape) {
const tensorflow::XlaTensor* xla_tensor =
tensorflow::XlaTensor::FromTensor(&tensor);
if (xla_tensor == nullptr) {
return TensorShapeToXLAShape(tensor.dtype(), tensor.shape(), shape);
}
const xla::ShapedBuffer& shaped_buffer = xla_tensor->shaped_buffer();
*shape = shaped_buffer.on_device_shape();
return Status::OK();
}
static DeviceAttributes BuildXlaDeviceAttributes(const string& name_prefix,
const string& device_name,

View File

@ -280,6 +280,8 @@ struct XlaDeviceOpRegistrations {
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
const char* jit_device);
Status DefaultPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_H_

View File

@ -1,106 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Registers the XLA_INTERPRETER device which exposes the XLA Interpreter.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
namespace tensorflow {
const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER";
const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT";
constexpr std::array<DataType, 10> kExecAllTypes = {
{DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
class XlaInterpreterDeviceFactory : public DeviceFactory {
public:
Status ListPhysicalDevices(std::vector<string>* devices) override;
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) override;
};
Status XlaInterpreterDeviceFactory::ListPhysicalDevices(
std::vector<string>* devices) {
devices->push_back(
absl::StrCat("/physical_device:", DEVICE_XLA_INTERPRETER, ":0"));
return Status::OK();
}
Status XlaInterpreterDeviceFactory::CreateDevices(
const SessionOptions& session_options, const string& name_prefix,
std::vector<std::unique_ptr<Device>>* devices) {
static XlaDeviceOpRegistrations* registrations = RegisterXlaDeviceKernels(
DEVICE_XLA_INTERPRETER, DEVICE_INTERPRETER_XLA_JIT);
(void)registrations;
XlaOpRegistry::DeviceRegistration registration;
registration.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
registration.autoclustering_policy =
XlaOpRegistry::AutoclusteringPolicy::kAlways;
registration.cluster_resource_variable_ops_unsafely = true;
registration.cluster_stack_ops = false;
registration.cluster_tensor_array_ops = true;
registration.cluster_stateful_rng_ops = true;
registration.cluster_control_trigger = true;
registration.elide_assert_and_checknumerics = true;
registration.cluster_variant_ops = true;
registration.cluster_slow_ops = true;
registration.cluster_inaccurate_ops = true;
XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_INTERPRETER,
registration);
TF_ASSIGN_OR_RETURN(
auto platform, se::MultiPlatformManager::PlatformWithName("Interpreter"));
XlaDevice::Options options;
options.platform = platform;
options.device_name_prefix = name_prefix;
options.device_name = DEVICE_XLA_INTERPRETER;
options.device_ordinal = 0;
options.compilation_device_name = DEVICE_INTERPRETER_XLA_JIT;
options.use_multiple_streams = false;
devices->push_back(absl::make_unique<XlaDevice>(session_options, options));
return Status::OK();
}
// Set priority to be below the default priority (50), so that Interpreter is
// not selected as a high priority device over other default devices. See
// constructor comments for Registrar in
// tensorflow/core/common_runtime/device_factory.h for a list of priority for
// devices.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_INTERPRETER,
XlaInterpreterDeviceFactory, 40);
// Kernel registrations
static bool OpFilter(KernelDef* kdef) { return true; }
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_INTERPRETER, XlaLocalLaunchOp,
kExecAllTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_INTERPRETER, XlaCompileOp,
kExecAllTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_INTERPRETER, XlaRunOp, kExecAllTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_INTERPRETER, kExecAllTypes);
REGISTER_XLA_BACKEND(DEVICE_INTERPRETER_XLA_JIT, kExecAllTypes, OpFilter);
} // namespace tensorflow

View File

@ -14,10 +14,62 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // end namespace
namespace tensorflow {
@ -27,6 +79,121 @@ bool XlaKernelCreator::CanCreateKernel(
return CanCreateXlaKernel(props->node_def);
}
static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
for (const auto& it : uncompilable_nodes_map) {
for (const auto& info : it.second.second) {
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message = absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n",
"\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
return errors::InvalidArgument(message);
}
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
Status XlaKernelCreator::CreateKernel(
FunctionLibraryRuntime* flr,
const std::shared_ptr<const NodeProperties>& props,
@ -34,19 +201,12 @@ Status XlaKernelCreator::CreateKernel(
return CreateXlaKernel(flr, props->node_def, kernel);
}
namespace {
bool RegisterLaunchOpCreator() {
static bool RegisterLaunchOpCreator() {
XlaKernelCreator* xla_kernel_creator = new XlaKernelCreator();
RegisterDefaultCustomKernelCreator(xla_kernel_creator);
return true;
}
static bool register_me = RegisterLaunchOpCreator();
static bool register_xla = [] {
SetXlaIsEnabled();
return true;
}();
} // end namespace
} // namespace tensorflow

View File

@ -1,186 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator_util.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
// Utility which searches for values in a sorted list by scanning over it once.
// No matter how many times ScanForValue is called, the list is scanned at most
// once. However, if a call to ScanForValue skips over a value, that value is
// not revisited in future calls to ScanForValue, so callers must take
// care to order their calls.
//
// Useful for merging multiple sorted lists in O(n) time.
class SinglePassSearch {
public:
// Creates a SinglePassSearch object that can be used to search in `values`.
// Does not take ownership of `values`. `values` must outlive this.
// `values` must be sorted.
explicit SinglePassSearch(const std::vector<int>* values)
: current_index_(0), values_(values) {}
// Scans forward in the vector looking for "value", updating the internal
// position in to the vector.
// Returns true iff the vector contains the given value at or after current
// position.
// Not thread-safe.
bool ScanForValue(int value) {
while (current_index_ < values_->size() &&
(*values_)[current_index_] <= value) {
if ((*values_)[current_index_] == value) {
current_index_++;
return true;
}
current_index_++;
}
return false;
}
private:
int current_index_;
const std::vector<int>* values_;
};
} // namespace
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel) {
if (!CanCreateXlaKernel(node_def)) {
return errors::Internal("Invalid node: ", node_def.ShortDebugString());
}
VLOG(3) << "Attempting to create XlaLaunchOp for " << node_def.DebugString();
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Only check for compilability if the MLIR bridge is not enabled.
if (!GetMlirCommonFlags()->tf_mlir_enable_mlir_bridge) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
uncompilable_node_info;
for (const auto& it : uncompilable_nodes_map) {
for (const auto& info : it.second.second) {
uncompilable_node_info.emplace_back(info);
}
}
string message = absl::StrCat(
"Function invoked by the following node is not compilable: ",
SummarizeNodeDef(node_def, /*max_inputs_in_summary=*/10), ".\n");
absl::StrAppend(&message, "Uncompilable nodes:");
for (const auto& node_info : uncompilable_node_info) {
string node_message = absl::StrCat("\n", node_info.name, ": ",
node_info.uncompilable_reason, "\n",
"\tStacktrace:\n");
for (const auto& stack_frame : node_info.stack_trace) {
absl::StrAppendFormat(&node_message, "\t\tNode: %s, function: %s\n",
stack_frame.name, stack_frame.function_name);
}
absl::StrAppend(&message, node_message);
}
VLOG(1) << message;
return errors::InvalidArgument(message);
}
}
// Get function body, constant args, and resource args.
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, node_def, &fbody, &constant_arg_indices, &resource_arg_indices));
// Set input and output memory types.
MemoryTypeVector input_memory_types(fbody->arg_types.size(), DEVICE_MEMORY);
// These indices are used only for optimization purposes. They allow us
// to loop over constant_arg_indices and resource_arg_indices only once
// while iterating over all the function arguments checking if it is a
// resource or a constant.
// The reason we optimized this code is because functions can have a lot of
// captured arguments. For example, the backward pass of ResNet50 takes in all
// 214 variables and a similar number of activations.
SinglePassSearch constants_search(&constant_arg_indices);
SinglePassSearch resources_search(&resource_arg_indices);
for (size_t i = 0; i < fbody->arg_types.size(); ++i) {
if (resources_search.ScanForValue(i) || constants_search.ScanForValue(i)) {
// Compile-time constants and resource handles are expected to be in
// host memory.
input_memory_types[i] = HOST_MEMORY;
}
}
// One might wonder, about the case where a compile-time constant argument
// (which must be in host memory) is also used as an input into an op,
// e.g. Add, that expects its inputs in device memory. Here is how it
// works now.
// First, what do we mean by "op expects an input in XYZ memory"?
// There are two types of "ops" here: the tf2xla kernel and the HLO
// computation it builds. The tf2xla kernel needs to retrieve the actual
// numeric value of the compile-time constant tensors, so it really expects
// them to be on in host memory. However, for other inputs, it refers to them
// using xla::ComputationDataHandle, which is just a symbolic handle that
// xla::ComputationBuilder assigns. How does this handle gets assigned for
// constant arguments? Even constant arguments get an _Arg node in the graph
// instantiated for Function compilation. The tf2xla kernel for constant _Arg
// nodes takes the constant value, converts it to XlaLiteral, and feeds it
// to xla::ComputationBuilder.ConstantLiteral, which returns the handle. This
// constant XlaLiteral is included in the HLO graph, and subsequently, in
// the actual executable, which is copied to the device before being
// executed. Thus, when this executable runs, the constant is available in
// device memory.
// XlaLaunch kernel keeps all outputs (including constants, which it copies),
// in device memory except for resources.
MemoryTypeVector output_memory_types(fbody->ret_types.size(), DEVICE_MEMORY);
for (size_t i = 0; i < fbody->ret_types.size(); ++i) {
if (fbody->ret_types[i] == DT_RESOURCE) {
output_memory_types[i] = HOST_MEMORY;
}
}
// Create the kernel.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
Device* dev = flr->device();
Status s;
auto props = std::make_shared<NodeProperties>(
&fbody->fdef.signature(), node_def, fbody->arg_types, fbody->ret_types);
OpKernelConstruction construction(DeviceType(dev->device_type()), dev,
dev->GetAllocator(AllocatorAttributes()),
flr, dev->resource_manager(), props,
input_memory_types, output_memory_types,
flr->graph_def_version(), &s);
*kernel = absl::make_unique<XlaLocalLaunchBase>(
&construction, constant_arg_indices, resource_arg_indices, function,
/*has_ref_vars=*/false);
return s;
}
} // namespace tensorflow

View File

@ -1,33 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
#define TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
class FunctionLibraryRuntime;
class OpKernel;
// Given a supported NodeDef, returns a XlaLaunchOp that computes the node.
Status CreateXlaKernel(FunctionLibraryRuntime* flr, const NodeDef& node_def,
std::unique_ptr<OpKernel>* kernel);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_KERNEL_CREATOR_UTIL_H_

View File

@ -0,0 +1,89 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Register XlaXXX operations on regular CPU/GPU devices using
// `XlaCompileOnDemandOp`.
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
#define REGISTER_XLA_OPS_ON_DEVICE(DEVICE) \
REGISTER_KERNEL_BUILDER(Name("XlaConv") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.HostMemory("lhs_dilation") \
.HostMemory("rhs_dilation") \
.HostMemory("feature_group_count") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER( \
Name("XlaBroadcastHelper").HostMemory("broadcast_dims").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelfAdjointEig").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSvd").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDot").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDynamicUpdateSlice").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaIf").Device(DEVICE), XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaPad").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaRecv").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduce").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReduceWindow").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSelectAndScatter") \
.HostMemory("window_dimensions") \
.HostMemory("window_strides") \
.HostMemory("padding") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSend").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaKeyValueSort").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaWhile").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaDequantize").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaEinsum").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSpmdShardToFullShape").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaSharding").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaReplicaId").Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaGather") \
.HostMemory("start_indices") \
.HostMemory("slice_sizes") \
.Device(DEVICE), \
XlaCompileOnDemandOp); \
REGISTER_KERNEL_BUILDER(Name("XlaScatter").Device(DEVICE), \
XlaCompileOnDemandOp);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_CPU);
REGISTER_XLA_OPS_ON_DEVICE(DEVICE_GPU);
} // namespace tensorflow

View File

@ -0,0 +1,159 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/xla/client/client_library.h"
namespace tensorflow {
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache) {
if (platform_info.xla_device_metadata()) {
*cache = new XlaCompilationCache(
platform_info.xla_device_metadata()->client(),
platform_info.xla_device_metadata()->jit_device_type());
return Status::OK();
}
auto platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id());
if (!platform.ok()) {
return platform.status();
}
xla::StatusOr<xla::Compiler*> compiler_for_platform =
xla::Compiler::GetForPlatform(platform.ValueOrDie());
if (!compiler_for_platform.ok()) {
// In some rare cases (usually in unit tests with very small clusters) we
// may end up transforming an XLA cluster with at least one GPU operation
// (which would normally force the cluster to be compiled using XLA:GPU)
// into an XLA cluster with no GPU operations (i.e. containing only CPU
// operations). Such a cluster can fail compilation (in way that
// MarkForCompilation could not have detected) if the CPU JIT is not linked
// in.
//
// So bail out of _XlaCompile in this case, and let the executor handle the
// situation for us.
const Status& status = compiler_for_platform.status();
if (status.code() == error::NOT_FOUND) {
return errors::Unimplemented("Could not find compiler for platform ",
platform.ValueOrDie()->Name(), ": ",
status.ToString());
}
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
client_options.set_intra_op_parallelism_threads(
ctx->device()->tensorflow_cpu_worker_threads()->num_threads);
auto client = xla::ClientLibrary::GetOrCreateLocalClient(client_options);
if (!client.ok()) {
return client.status();
}
const XlaOpRegistry::DeviceRegistration* registration;
if (!XlaOpRegistry::GetCompilationDevice(platform_info.device_type().type(),
&registration)) {
return errors::InvalidArgument("No JIT device registered for ",
platform_info.device_type().type());
}
*cache = new XlaCompilationCache(
client.ValueOrDie(), DeviceType(registration->compilation_device_name));
return Status::OK();
}
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx) {
DeviceType device_type = ctx->device_type();
se::Platform::Id platform_id = nullptr;
const XlaDevice::Metadata* xla_device_metadata = nullptr;
se::DeviceMemoryAllocator* custom_allocator = nullptr;
if (ctx->device_type() == DeviceType(DEVICE_CPU)) {
platform_id = se::host::kHostPlatformId;
} else if (ctx->device_type() == DeviceType(DEVICE_GPU)) {
platform_id = ctx->device()
->tensorflow_gpu_device_info()
->stream->parent()
->platform()
->id();
} else if (XlaDevice::GetMetadata(ctx, &xla_device_metadata).ok()) {
// If we are on an XlaDevice, use the underlying XLA platform's allocator
// directly. We could use the StreamExecutor's allocator which may
// theoretically be more correct, but XLA returns a nice OOM message in a
// Status and StreamExecutor does not.
//
// Importantly we can't use ctx->device()->GetAllocator() as the allocator
// (which xla_allocator above uses) as on an XlaDevice, this is a dummy
// allocator that returns XlaTensor objects. The XlaCompiler needs a real
// allocator to allocate real buffers.
platform_id = xla_device_metadata->platform()->id();
custom_allocator =
xla_device_metadata->client()->backend().memory_allocator();
}
return XlaPlatformInfo(device_type, platform_id, xla_device_metadata,
custom_allocator);
}
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info) {
if (platform_info.custom_allocator()) {
return platform_info.custom_allocator();
}
if (!ctx->op_device_context()) {
// Stream is not set for the host platform.
se::Platform* platform =
se::MultiPlatformManager::PlatformWithId(platform_info.platform_id())
.ValueOrDie();
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}), platform);
return &tf_allocator_adapter->value();
}
tf_allocator_adapter->emplace(ctx->device()->GetAllocator({}),
ctx->op_device_context()->stream());
return &tf_allocator_adapter->value();
}
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter) {
CHECK(ctx->function_library());
XlaCompiler::Options options;
options.client = static_cast<xla::LocalClient*>(cache.client());
if (ctx->op_device_context() != nullptr) {
options.device_ordinal =
ctx->op_device_context()->stream()->parent()->device_ordinal();
}
options.device_type = cache.device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
options.allow_cpu_custom_calls =
(platform_info.platform_id() == se::host::kHostPlatformId);
options.device_allocator =
GetAllocator(tf_allocator_adapter, ctx, platform_info);
if (platform_info.xla_device_metadata()) {
options.shape_representation_fn =
platform_info.xla_device_metadata()->shape_representation_fn();
}
// If reference variables are not present in the graph, we can safely alias
// passthrough parameters without performing a copy.
options.alias_passthrough_params =
!has_ref_vars && !platform_info.is_on_xla_device();
return options;
}
} // namespace tensorflow

View File

@ -0,0 +1,108 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
#include "tensorflow/compiler/jit/xla_compilation_cache.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
namespace tensorflow {
// Holds some information about the platform on which an
// XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of
// abstraction for normal and XLA devices.
class XlaPlatformInfo {
public:
XlaPlatformInfo() : device_type_("") {}
XlaPlatformInfo(XlaPlatformInfo&&) = default;
explicit XlaPlatformInfo(const DeviceType device_type,
se::Platform::Id platform_id,
const XlaDevice::Metadata* xla_device_metadata,
se::DeviceMemoryAllocator* device_allocator)
: device_type_(device_type),
platform_id_(platform_id),
xla_device_metadata_(xla_device_metadata),
device_allocator_(device_allocator) {}
XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
bool UseMultipleStreams() const {
return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
}
// Non-null only when run on an XLA device.
se::DeviceMemoryAllocator* custom_allocator() const {
return device_allocator_;
}
DeviceType device_type() const { return device_type_; }
// This is equal to xla_device_metadata()->platform()->id() if
// xla_device_metadata() is not nullptr.
se::Platform::Id platform_id() const { return platform_id_; }
// This may be null if the op this XlaPlatformInfo is for was not placed on an
// XLA device.
const XlaDevice::Metadata* xla_device_metadata() const {
return xla_device_metadata_;
}
bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
private:
DeviceType device_type_;
se::Platform::Id platform_id_;
// xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
// XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
// XlaLaunch/_XlaCompile/_XlaRun OpKernel.
const XlaDevice::Metadata* xla_device_metadata_;
// If the op associated with this XlaPlatformInfo is placed on an XLA device
// then device_allocator_ is the xla::Backend's memory allocator. If the op
// is placed on a regular CPU or GPU device then device_allocator_ is null.
se::DeviceMemoryAllocator* device_allocator_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
};
// Returns created XLA compilation cache.
Status BuildXlaCompilationCache(OpKernelContext* ctx,
const XlaPlatformInfo& platform_info,
XlaCompilationCache** cache);
// Returns information about the platform from kernel context.
XlaPlatformInfo XlaPlatformInfoFromContext(OpKernelConstruction* ctx);
// Returns allocator from platform info if non-null, or populate and return a
// pointer to the allocator adapter with allocator from context.
//
// This is necessary because for XLA devices the underlying TF allocator returns
// dummy tensors.
se::DeviceMemoryAllocator* GetAllocator(
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
OpKernelContext* ctx, const XlaPlatformInfo& platform_info);
// Returns created options for the XLA compiler, and writes the used allocator
// into `tf_allocator_adapter`.
XlaCompiler::Options GenerateCompilerOptions(
const XlaCompilationCache& cache, OpKernelContext* ctx,
const XlaPlatformInfo& platform_info, bool has_ref_vars,
absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_

View File

@ -0,0 +1,4 @@
build
llvm-project
llvm-build

View File

@ -404,6 +404,7 @@ cc_library(
cc_library(
name = "lhlo_legalize_to_llvm",
srcs = ["lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc"],
hdrs = ["include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h"],
deps = [
":lhlo",
"@llvm-project//mlir:IR",
@ -759,8 +760,6 @@ cc_library(
":lhlo_legalize_to_llvm", # build-cleaner: keep
":materialize_broadcasts", # build-cleaner: keep
":unfuse_batch_norm", # build-cleaner: keep
"@llvm-project//mlir:AffineToStandardTransforms",
"@llvm-project//mlir:CFGTransforms",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:LLVMDialect",
@ -807,13 +806,6 @@ cc_library(
],
)
cc_library(
name = "register_all_passes",
srcs = ["lib/Dialect/mhlo/transforms/register_all_passes.cc"],
deps = [":all_passes"],
alwayslink = 1,
)
cc_binary(
name = "mlir-hlo-opt",
srcs = [

View File

@ -0,0 +1,94 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cmake_minimum_required(VERSION 3.13.4)
if(POLICY CMP0068)
cmake_policy(SET CMP0068 NEW)
set(CMAKE_BUILD_WITH_INSTALL_NAME_DIR ON)
endif()
if(POLICY CMP0075)
cmake_policy(SET CMP0075 NEW)
endif()
if(POLICY CMP0077)
cmake_policy(SET CMP0077 NEW)
endif()
#-------------------------------------------------------------------------------
# Project setup and globals
#-------------------------------------------------------------------------------
project(mlir-hlo LANGUAGES CXX C)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_STANDARD 14)
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
#-------------------------------------------------------------------------------
# Options and settings
#-------------------------------------------------------------------------------
#-------------------------------------------------------------------------------
# MSVC defaults
#-------------------------------------------------------------------------------
if(MSVC)
add_compile_options(
$<$<CONFIG:>:/MD>
$<$<CONFIG:Debug>:/MD>
$<$<CONFIG:Release>:/MD>
)
endif()
#-------------------------------------------------------------------------------
# MLIR/LLVM Configuration
#-------------------------------------------------------------------------------
find_package(MLIR REQUIRED CONFIG)
message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
if(LLVM_ENABLE_ZLIB)
find_package(ZLIB)
endif()
include(TableGen)
include(AddLLVM)
include(AddMLIR)
include(HandleLLVMOptions)
include_directories(${LLVM_INCLUDE_DIRS})
include_directories(${MLIR_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/include)
include_directories(${PROJECT_BINARY_DIR}/)
link_directories(${LLVM_BUILD_LIBRARY_DIR})
add_definitions(${LLVM_DEFINITIONS})
#-------------------------------------------------------------------------------
# Directory setup
#-------------------------------------------------------------------------------
set(MLIR_HLO_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(MLIR_HLO_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
add_custom_target(check-mlir-hlo)
add_subdirectory(include/mlir-hlo)
add_subdirectory(lib)
add_subdirectory(tools)
add_subdirectory(tests)

View File

@ -1,4 +1,4 @@
# MLIR-HLO
# MLIR-HLO: A Standalone "HLO" MLIR-based Compiler
The code here exists in two places:
@ -22,10 +22,43 @@ upstream.
## QuickStart: building and testing
TODO
These instructions work on Linux, you may have to adjust for your plaform.
To build the code in this repository, you need a clone of the LLVM/MLIR git
repository:
$ git clone https://github.com/llvm/llvm-project.git
You need to make sure you have the right commit checked out in the LLVM
repository (you need to do this every time you pull from this repo):
$ (cd llvm-project && git checkout $(cat build_tools/llvm_version.txt))
We provide a script to configure and build LLVM/MLIR:
$ build_tools/build_mlir.sh ${PWD}/llvm-project/ ${PWD}/llvm-build
Again this is something to do every time you pull from this repository and the
LLVM revision changes.
Finally you can build and test this repository:
$ mkdir build && cd build
$ cmake .. -GNinja \
-DLLVM_ENABLE_LLD=ON \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=On \
-DMLIR_DIR=${PWD}/../llvm-build/lib/cmake/mlir
$ ninja check-mlir-hlo
## Overview
MLIR-HLO aims to provide an end-to-end compiler for CPU and GPU, as well as
building reusable blocks for other accelerators. This is heavily inspired by the
success of XLA.
[XLA](https://www.tensorflow.org/xla/) (Accelerated Linear Algebra) is a
domain-specific compiler framework and execution environment for linear algebra,
which powers code-generation for ML frameworks like TensorFlow, JAX, and others.

View File

@ -0,0 +1,16 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(Dialect)

View File

@ -0,0 +1,16 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(mhlo)

View File

@ -0,0 +1,17 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(IR)
add_subdirectory(transforms)

View File

@ -0,0 +1,31 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Need a separate function because of the .cc vs .cpp used in the one provided by MLIR
function(add_mlir_hlo_dialect dialect dialect_namespace)
set(LLVM_TARGET_DEFINITIONS ${dialect}.td)
mlir_tablegen(${dialect}.h.inc -gen-op-decls)
mlir_tablegen(${dialect}.cc.inc -gen-op-defs)
mlir_tablegen(${dialect}_structs.h.inc -gen-struct-attr-decls)
mlir_tablegen(${dialect}_structs.cc.inc -gen-struct-attr-defs)
add_public_tablegen_target(MLIR${dialect}IncGen)
add_dependencies(mlir-headers MLIR${dialect}IncGen)
endfunction()
add_mlir_hlo_dialect(chlo_ops chlo)
add_mlir_hlo_dialect(hlo_ops mhlo)
add_mlir_hlo_dialect(lhlo_ops lmhlo)
add_mlir_interface(infer_fusibility_op_interface)

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
@ -32,14 +33,33 @@ namespace mlir {
namespace chlo {
class HloClientDialect : public Dialect {
void initialize();
public:
explicit HloClientDialect(MLIRContext *context);
explicit HloClientDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context,
TypeID::get<HloClientDialect>()) {
initialize();
}
static StringRef getDialectNamespace() { return "chlo"; }
};
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc"
template <typename T>
static Value getConstantLike(OpBuilder& b, T constant, Value val) {
Type ty = getElementTypeOrSelf(val.getType());
auto getAttr = [&]() -> Attribute {
if (ty.isa<IntegerType>()) return b.getIntegerAttr(ty, constant);
if (ty.isa<FloatType>()) return b.getFloatAttr(ty, constant);
llvm_unreachable("unhandled element type");
};
// TODO(jpienaar): Add ability to pass loc via native call and update.
return b.create<ConstantLikeOp>(b.getUnknownLoc(), getAttr(), val);
}
} // namespace chlo
} // namespace mlir

View File

@ -364,6 +364,24 @@ def HLOClient_AcosOp: HLOClient_UnaryElementwiseOp<"acos",
}];
}
def HLOClient_ConstantLikeOp: HLOClient_Op<"constant_like",
[NoSideEffect, SameOperandsAndResultShape,
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface>,
NativeOpTrait<"InferTensorType">]> {
let summary = "Constant like operator";
let description = [{
Returns a splat constant of the same shape as the operand.
}];
// TODO(jpienaar): value's type could be tightened.
let arguments = (ins AnyAttr:$value, HLO_Tensor:$operand);
let results = (outs HLO_Tensor);
let hasCanonicalizer = 1;
}
//===----------------------------------------------------------------------===//
// Broadcasting compare op
//===----------------------------------------------------------------------===//

View File

@ -69,9 +69,6 @@ class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
static TokenType get(MLIRContext *context) {
return Base::get(context, HLOTypes::Token);
}
// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) { return kind == HLOTypes::Token; }
};
// Shape derivation function that computes the shape of the result based on

View File

@ -67,8 +67,7 @@ def HLO_ConstOp : HLO_Op<"constant",
"OpBuilder &builder, OperationState &result, Attribute value"
>];
let printer = [{ return Print(*this, &p); }];
let parser = [{ return ParseConstOp(&parser, &result); }];
let assemblyFormat = "attr-dict $value";
let hasFolder = 1;
@ -671,6 +670,7 @@ def HLO_TupleOp : HLO_Op<"tuple", [NoSideEffect]>, BASE_HLO_TupleOp {
"OpBuilder &builder, OperationState &results, "
"ValueRange values">];
let hasCanonicalizer = 1;
}
def HLO_CompareOp: HLO_Op<"compare",
@ -1329,8 +1329,9 @@ def HLO_TorchIndexSelectOp : HLO_Op<"torch_index_select", [NoSideEffect]> {
}
//===----------------------------------------------------------------------===//
// MHLO RngUniform Operator.
// MHLO RNG Operators.
//===----------------------------------------------------------------------===//
def HLO_RngUniformOp : HLO_Op<"rng_uniform", []>, BASE_HLO_RngUniformOp {
let arguments = (ins
HLO_PredIntOrFpTensor:$a,
@ -1355,6 +1356,19 @@ def HLO_RngNormalOp : HLO_Op<"rng_normal", []>, BASE_HLO_RngNormalOp {
let hasCustomHLOConverter = 1;
}
def HLO_RngBitGeneratorOp : HLO_Op<"rng_bit_generator", [NoSideEffect]>, BASE_HLO_RngBitGeneratorOp {
let arguments = (ins
// TODO(jpienaar): This could be an enum instead.
I32Attr:$rng_algorithm,
HLO_IntOrFpTensor:$initial_state
);
let results = (outs HLO_TensorOrTuple:$result);
// TODO(jpienaar): This should not be needed.
let hasCustomHLOConverter = 1;
}
//===----------------------------------------------------------------------===//
// MHLO Quantize Operator.
//===----------------------------------------------------------------------===//

View File

@ -316,6 +316,19 @@ class BASE_HLO_RealOp {
}];
}
class BASE_HLO_RngBitGeneratorOp {
string summary = "Uniform random number generator operator";
string description = [{
Returns an output with a given shape filled with uniform random bits using
the specified algorithm (or backend default) and returns an updated state
(with the same shape as initial state) and the generated random data.
See
https://www.tensorflow.org/xla/operation_semantics#rngbitgenerator.
}];
}
class BASE_HLO_RoundOp {
string summary = "Round operator";

View File

@ -27,6 +27,9 @@ def CastIntElementsAttr : NativeCodeCall<"$0.cast<DenseIntElementsAttr>()">;
class ConstantSplat<string value> : NativeCodeCall<
"hlo::getSplat(&$_builder, $0, " # value # ")">;
class HLO_ConstantLike<string value> : NativeCodeCall<
"chlo::getConstantLike($_builder, " # value # ", $0)">;
def NullDenseIntElementsAttr : NativeCodeCall<"DenseIntElementsAttr()">;
def BinBroadcastDimensions : NativeCodeCall<

View File

@ -66,6 +66,8 @@ def LHLO_PredOrIntBuffer : MemRefOf<[HLO_Int, HLO_Pred]>;
def LHLO_Buffer : MemRefOf<[AnyFloat, AnySignlessInteger, AnyComplex]>;
def LHLO_ExtentBuffer : MemRefRankOf<[AnySignlessInteger, Index], [1]>;
//===----------------------------------------------------------------------===//
// LMHLO nullary op definitions.
//===----------------------------------------------------------------------===//
@ -467,7 +469,7 @@ def ReshapeMemRefCastOp: Op<LHLO_Dialect, "reshape_memref_cast", [
let arguments = (ins
AnyRankedOrUnrankedMemRef:$operand,
MemRefRankOf<[AnySignlessInteger], [1]>:$shape
LHLO_ExtentBuffer:$shape
);
let results = (outs AnyRankedOrUnrankedMemRef:$result);

View File

@ -0,0 +1,23 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set(LLVM_TARGET_DEFINITIONS mhlo_passes.td)
mlir_tablegen(mhlo_passes.h.inc -gen-pass-decls -name MHLO)
add_public_tablegen_target(MLIRMhloPassIncGen)
set(LLVM_TARGET_DEFINITIONS lmhlo_passes.td)
mlir_tablegen(lmhlo_passes.h.inc -gen-pass-decls -name LMHLO)
add_public_tablegen_target(MLIRLmhloPassIncGen)

View File

@ -38,10 +38,12 @@ bool IsLegalNumpyRankedBroadcast(Value lhs, Value rhs,
// Emits shape dialect ops to compute the result shape for a broadcasting
// binary elementwise op which broadcasts according to "numpy" semantics
// (see above), returning an extents tensor of the resulting shape.
Value ComputeBinaryElementwiseBroadcastingResultExtents(Location loc, Value lhs,
Value rhs,
OpBuilder& builder);
// (see above), returning a `shape.shape` or an extent tensor of the resulting
// shape. The result should only be an extent tensor in contexts that ensure
// both operands to be broadcastable.
Value ComputeBinaryElementwiseBroadcastingResultExtents(
Location loc, Value lhs, Value rhs, OpBuilder& builder,
bool unsafe_as_extent_tensor);
} // namespace hlo
} // namespace mlir

View File

@ -0,0 +1,17 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(Dialect)
add_subdirectory(utils)

View File

@ -0,0 +1,16 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(mhlo)

View File

@ -0,0 +1,17 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
add_subdirectory(IR)
add_subdirectory(transforms)

View File

@ -0,0 +1,82 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
include_directories(BEFORE
${CMAKE_CURRENT_BINARY_DIR}
${CMAKE_CURRENT_SOURCE_DIR})
set(LLVM_TARGET_DEFINITIONS hlo_patterns.td)
mlir_tablegen(hlo_patterns.cc.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloRewriterIncGen)
set(LLVM_TARGET_DEFINITIONS mhlo_canonicalize.td)
mlir_tablegen(mhlo_canonicalize.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloCanonicalizeIncGen)
add_mlir_dialect_library(ChloDialect
chlo_ops.cc
DEPENDS
MLIRchlo_opsIncGen
)
target_link_libraries(ChloDialect PUBLIC MLIRIR)
add_mlir_library(MhloInferFusibilityOpInterface
infer_fusibility_op_interface.cc
DEPENDS
MLIRinfer_fusibility_op_interfaceIncGen
)
add_mlir_dialect_library(MhloDialect
hlo_ops.cc
DEPENDS
MLIRhlo_opsIncGen
MLIRMhloCanonicalizeIncGen
MLIRMhloRewriterIncGen
MLIRinfer_fusibility_op_interfaceIncGen
)
target_link_libraries(MhloDialect
PUBLIC
MLIRIR
MhloInferFusibilityOpInterface
MLIRMhloUtils
)
add_mlir_dialect_library(LmhloDialect
lhlo_ops.cc
DEPENDS
MLIRlhlo_opsIncGen
)
target_link_libraries(LmhloDialect PUBLIC MLIRIR)
add_mlir_dialect_library(MhloRegisterDialects
init.cc
DEPENDS
MLIRchlo_opsIncGen
MLIRhlo_opsIncGen
MLIRlhlo_opsIncGen
)
target_link_libraries(MhloRegisterDialects
PUBLIC
ChloDialect
MhloDialect
LmhloDialect
)

View File

@ -0,0 +1,30 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This is the canonicalize pattern definition file.
include "mlir/IR/OpBase.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.td"
include "mlir-hlo/Dialect/mhlo/IR/hlo_utils.td"
def UnaryToBinaryEinsumEq : NativeCodeCall<
"$_builder.getStringAttr(\",\" + $0.getValue().str())">;
// Convert UnaryEinsumOp to EinsumOp with two operands with redundant first
// operand.
def UnaryEinsumToEinsum : Pat<
(HLO_UnaryEinsumOp $operand, $equation),
(HLO_EinsumOp (HLO_ConstOp (GetScalarOfType<1> $operand)),
$operand, (UnaryToBinaryEinsumEq $equation))>;

View File

@ -15,10 +15,12 @@ limitations under the License.
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/utils/broadcast_utils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
@ -151,7 +153,7 @@ LogicalResult ReifyBroadcastBinaryOpReturnTypeShapes(
}
Value computed_shape = hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, builder);
loc, lhs, rhs, builder, /*unsafe_as_extent_tensor=*/false);
if (!computed_shape) return failure();
reifiedReturnShapes.push_back(computed_shape);
return success();
@ -259,6 +261,48 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
#undef BROADCAST_INFER_SHAPE_TYPE_OP_DEFS
#undef BROADCAST_BINARY_OP_DEFS
static LogicalResult Verify(ConstantLikeOp op) {
if (op.value().getType() != op.getType().cast<ShapedType>().getElementType())
return op.emitOpError() << "value's type doesn't match element return type";
return success();
}
LogicalResult ConstantLikeOp::inferReturnTypeComponents(
MLIRContext* context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferedReturnShapes) {
ConstantLikeOp::Adaptor op(operands, attributes);
if (failed(op.verify(location.getValue()))) return failure();
Type element_type = op.value().getType();
Type operand_type = op.operand().getType();
if (operand_type.isa<UnrankedTensorType>()) {
inferedReturnShapes.emplace_back(element_type);
} else {
const auto& shape = operand_type.cast<RankedTensorType>().getShape();
inferedReturnShapes.emplace_back(shape, element_type);
}
return success();
}
struct ConstantLikeToConstant : public OpRewritePattern<ConstantLikeOp> {
using OpRewritePattern<ConstantLikeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConstantLikeOp op,
PatternRewriter& rewriter) const override {
auto op_type = op.operand().getType().cast<ShapedType>();
if (!op_type.hasStaticShape()) return failure();
auto type = RankedTensorType::get(op_type.getShape(), op.value().getType());
ElementsAttr attr = DenseElementsAttr::get(type, op.value());
rewriter.replaceOpWithNewOp<mhlo::ConstOp>(op.getOperation(), attr);
return success();
}
};
void ConstantLikeOp::getCanonicalizationPatterns(
OwningRewritePatternList& results, MLIRContext* context) {
results.insert<ConstantLikeToConstant>(context);
}
#define GET_OP_CLASSES
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"
@ -266,8 +310,7 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp);
// chlo Dialect Constructor
//===----------------------------------------------------------------------===//
HloClientDialect::HloClientDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
void HloClientDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc"

View File

@ -112,37 +112,6 @@ DenseIntElementsAttr BuildSliceLimits(DenseIntElementsAttr start_indices,
// ConstOp
//===----------------------------------------------------------------------===//
static void Print(ConstOp op, OpAsmPrinter* printer) {
// Print op name.
*printer << op.getOperationName();
// Elide attribute value while printing the attribute dictionary.
SmallVector<StringRef, 1> elided_attrs;
elided_attrs.push_back("value");
printer->printOptionalAttrDict(op.getAttrs(), elided_attrs);
*printer << ' ' << op.value();
}
static ParseResult ParseConstOp(OpAsmParser* parser, OperationState* result) {
if (parser->parseOptionalAttrDict(result->attributes)) return failure();
// If colon is not present after attribute dictionary, it should be short form
// and attribute 'value' is outside the dictionary.
if (failed(parser->parseOptionalColon())) {
Attribute value;
if (parser->parseAttribute(value, "value", result->attributes))
return failure();
return parser->addTypeToList(value.getType(), result->types);
}
// Long form should have type of the result after colon.
Type ty;
if (parser->parseType(ty)) return failure();
result->types.push_back(ty);
return success();
}
OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
assert(operands.empty() && "constant has no operands");
@ -340,6 +309,33 @@ void DynamicIotaOp::getCanonicalizationPatterns(
results.insert<DynamicIotaBroadcast>(context);
}
//===----------------------------------------------------------------------===//
// DynamicUpdateSliceOp
//===----------------------------------------------------------------------===//
static LogicalResult Verify(DynamicUpdateSliceOp op) {
OperandRange indices = op.start_indices();
if (indices.size() <= 1) return success();
// Note: start_indices is constrained to Variadic<HLO_ScalarIntTensor>, so it
// is OK to cast indices to ShapedType here.
auto idx_tensor = indices.take_front().front().getType().cast<ShapedType>();
Type first_elem_ty = idx_tensor.getElementType();
Type elem_ty;
for (auto idx : llvm::drop_begin(indices, 1)) {
idx_tensor = idx.getType().cast<ShapedType>();
elem_ty = idx_tensor.getElementType();
if (first_elem_ty != elem_ty) {
return op.emitOpError() << "start indices must have same element type "
"(encountered mismatch: "
<< first_elem_ty << " vs " << elem_ty << ")";
}
}
return success();
}
//===----------------------------------------------------------------------===//
// AbsOp
//===----------------------------------------------------------------------===//
@ -506,6 +502,46 @@ static LogicalResult Verify(TupleOp op) {
return success();
}
namespace {
// Pattern for unpacking and repacking the same tuple.
struct UnpackRepackSameTuple : public OpRewritePattern<TupleOp> {
using OpRewritePattern<TupleOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TupleOp op,
PatternRewriter& rewriter) const override {
if (op.val().empty()) return failure();
Value first_element = op.val().front();
auto first_element_op =
dyn_cast_or_null<GetTupleElementOp>(first_element.getDefiningOp());
if (!first_element_op || first_element_op.indexAttr().getInt() != 0)
return failure();
Value tuple_predecessor = first_element_op.getOperand();
if (tuple_predecessor.getType() != op.getType()) return failure();
for (auto element_and_idx : llvm::enumerate(op.val().drop_front(1))) {
auto element_op = dyn_cast_or_null<GetTupleElementOp>(
element_and_idx.value().getDefiningOp());
if (!element_op ||
element_op.indexAttr().getInt() != element_and_idx.index() + 1 ||
element_op.getOperand() != tuple_predecessor)
return failure();
}
rewriter.replaceOp(op, tuple_predecessor);
return success();
}
};
} // namespace
void TupleOp::getCanonicalizationPatterns(OwningRewritePatternList& results,
MLIRContext* context) {
results.insert<UnpackRepackSameTuple>(context);
}
//===----------------------------------------------------------------------===//
// AllToAllOp
//===----------------------------------------------------------------------===//
@ -708,10 +744,12 @@ static LogicalResult Verify(DynamicBroadcastInDimOp op) {
auto dimSize = operandType.getDimSize(i);
auto resultDimSize = resultType.getDimSize(dimIndex);
if (dimSize != 1 && dimSize != resultDimSize) {
// Note: verifyCompatibleShapes doesn't consider size-1 broadcasting, so we
// add a manual check for this.
if (dimSize != 1 && failed(verifyCompatibleShape(dimSize, resultDimSize))) {
return op.emitOpError(
llvm::formatv("size of operand dimension {0} ({1}) is not equal to "
"1 or size of result dimension {2} ({3})",
llvm::formatv("size of operand dimension {0} ({1}) is not compatible "
"with size of result dimension {2} ({3})",
i, dimSize, dimIndex, resultDimSize));
}
}
@ -2150,7 +2188,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface {
//===----------------------------------------------------------------------===//
MhloDialect::MhloDialect(MLIRContext* context)
: Dialect(getDialectNamespace(), context) {
: Dialect(getDialectNamespace(), context, TypeID::get<MhloDialect>()) {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.cc.inc"

View File

@ -49,7 +49,7 @@ namespace mlir {
namespace lmhlo {
LmhloDialect::LmhloDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context) {
: Dialect(getDialectNamespace(), context, TypeID::get<LmhloDialect>()) {
addOperations<
#define GET_OP_LIST
#include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.cc.inc"

View File

@ -0,0 +1,155 @@
#
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
include_directories(BEFORE
${CMAKE_CURRENT_BINARY_DIR}
${CMAKE_CURRENT_SOURCE_DIR})
set(LLVM_TARGET_DEFINITIONS lower_complex_patterns.td)
mlir_tablegen(generated_lower_complex.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloLowerComplexIncGen)
set(LLVM_TARGET_DEFINITIONS legalize_to_standard_patterns.td)
mlir_tablegen(generated_legalize_to_standard.inc -gen-rewriters)
add_public_tablegen_target(MLIRMhloLegalizeToStandardIncGen)
add_mlir_library(ChloPasses
chlo_legalize_to_hlo.cc
chlo_legalize_to_hlo_pass.cc
DEPENDS
MLIRhlo_opsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
ChloDialect
MLIRIR
MLIRPass
)
add_mlir_library(MhloPasses
legalize_gather_to_torch_index_select.cc
legalize_tanh_to_approximation.cc
lower_complex.cc
lower_complex_patterns.td
lower_general_dot.cc
materialize_broadcasts.cc
materialize_broadcasts_pass.cc
mhlo_fusion.cc
optimize_mhlo.cc
optimize_mhlo_pass.cc
sink_constants_to_control_flow.cc
test_infer_shaped_type_pass.cc
transform_unranked_hlo.cc
unfuse_batch_norm.cc
unfuse_batch_norm_pass.cc
DEPENDS
MLIRhlo_opsIncGen
MLIRMhloLowerComplexIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRMhloUtils
MLIRPass
MLIRTransformUtils
)
add_mlir_library(MhloToLhloConversion
hlo_legalize_to_lhlo.cc
DEPENDS
MLIRhlo_opsIncGen
MLIRlhlo_opsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MhloDialect
LmhloDialect
MLIRIR
MLIRPass
)
add_mlir_library(MhloToStandard
legalize_control_flow.cc
legalize_to_standard.cc
DEPENDS
MLIRhlo_opsIncGen
MLIRlhlo_opsIncGen
MLIRMhloLegalizeToStandardIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
)
add_mlir_library(MhloLhloToLinalg
legalize_to_linalg.cc
DEPENDS
MLIRhlo_opsIncGen
MLIRlhlo_opsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MhloDialect
MLIRIR
MLIRPass
)
add_mlir_library(LmhloPasses
lhlo_copy_removal.cc
lhlo_fuse_linalg.cc
lhlo_legalize_to_affine.cc
lhlo_legalize_to_gpu.cc
lhlo_legalize_to_llvm.cc
lhlo_legalize_to_llvm_pass.cc
lhlo_legalize_to_parallel_loops.cc
DEPENDS
MLIRlhlo_opsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
LmhloDialect
MLIRIR
MLIRPass
)
add_library(AllMhloPasses INTERFACE)
target_link_libraries(AllMhloPasses INTERFACE
ChloPasses
MhloPasses
MhloToLhloConversion
MhloToStandard
MhloLhloToLinalg
LmhloPasses
)

View File

@ -124,8 +124,8 @@ struct ConvertRankedDynamicBroadcastBinaryOp
int64_t result_rank = std::max(lhs_type.getRank(), rhs_type.getRank());
Value result_extents =
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(loc, lhs, rhs,
rewriter);
hlo::ComputeBinaryElementwiseBroadcastingResultExtents(
loc, lhs, rhs, rewriter, /*unsafe_as_extent_tensor=*/true);
// Note that we unconditionally emit DynamicBroadcastInDim ops and let
// downstream canonicalizations fold them away if possible. This is

View File

@ -36,6 +36,10 @@ def IsSameSizePred : CPred<
def IsSameSizeConstraint : Constraint<IsSameSizePred, "inputs are same size">;
// Unary Lowering Patterns.
def : Pat<(HLO_CeilOp HLO_FpTensor:$i), (CeilFOp $i)>;
// Binary Lowering Patterns.
def : Pat<(HLO_AndOp HLO_PredTensor:$l, HLO_PredTensor:$r),
(AndOp $l, $r),
[(IsSameSizeConstraint $l, $r)]>;

Some files were not shown because too many files have changed in this diff Show More