Merge branch 'master' into interface_16x8
This commit is contained in: b6c7284053

Changed files:

tensorflow/c/
    c_api_experimental.cc
    eager/
        BUILD
        c_api.cc
        c_api_experimental.cc
        c_api_experimental_test.cc
        c_api_internal.h
        c_api_remote_test.cc
        c_api_test.cc
        c_api_unified_experimental.cc
        c_api_unified_experimental.h
        c_api_unified_experimental_test.cc
        context_interface.cc
        context_interface.h
        dlpack.cc
        operation_interface.cc
        operation_interface.h
        tensor_handle_interface.h
tensorflow/compiler/
    mlir/
    xla/
tensorflow/core/
    BUILD
    common_runtime/
        bfc_allocator.cc
        bfc_allocator.h
        direct_session.cc
        eager/
            context.cc
            execute.cc
            process_function_library_runtime.cc
            tensor_handle.cc
            tensor_handle.h
            tensor_handle_test.cc
        executor.cc
        function_test.cc
        function_threadpool_test.cc
        graph_view.cc
        graph_view.h
        metrics.cc
        process_function_library_runtime.cc
        process_function_library_runtime.h
        process_function_library_runtime_test.cc
    data/
    distributed_runtime/
    framework/
    kernels/
        batch_kernels.cc
        data/
        function_ops.cc
        in_topk_op_gpu.cu.cc
        ops_testutil.cc
        sparse/
            csr_sparse_matrix_to_dense_op.cc
            dense_to_csr_sparse_matrix_op.cc
            mat_mul_op.cc
            mul_op.cc
            sparse_matrix_components_op.cc
            sparse_tensor_to_csr_sparse_matrix_op.cc
        variable_ops.cc
        variable_ops.h
    profiler/
tensorflow/go/op/
tensorflow/lite/delegates/
    flex/
    gpu/
@@ -683,7 +683,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
namespace {

@@ -28,6 +28,7 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.cc",
"context_interface.h",
"operation_interface.cc",

@@ -64,6 +65,7 @@ tf_cuda_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [

@@ -97,6 +99,7 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",

@@ -112,6 +115,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
"context_interface.h",
"operation_interface.h",
"tensor_handle_interface.h",

@@ -210,7 +214,6 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)

@@ -219,8 +222,12 @@ tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
],
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api_experimental.h"],
copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({

@@ -246,6 +253,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [

@@ -297,6 +305,30 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_unified_experimental_test",
size = "small",
srcs = [
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",
@@ -919,7 +919,10 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {

@@ -1074,10 +1077,12 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
}
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
std::unique_ptr<tensorflow::AbstractTensorHandleInterface>(
h->handle->Copy())};
}
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
tensorflow::AbstractTensorHandleInterface*
tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}

@@ -1166,8 +1171,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);

@@ -1228,19 +1232,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device,
device, context))};
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context))};
}
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
// This function will block till the operation that produces `h` has

@@ -1254,9 +1256,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
return 0;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "

@@ -1309,8 +1309,8 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i].reset(inputs[i]->handle->Copy());
}

@@ -1504,8 +1504,8 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;

@@ -1529,10 +1529,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
tensorflow::TensorHandleFromInterface(h->handle), &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};

@@ -1549,10 +1546,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};

@@ -1562,9 +1556,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
context, &context->Executor(), device, false, &handle);
tensorflow::TensorHandleFromInterface(h->handle), context,
&context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};

@@ -1622,7 +1615,9 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return TFE_TensorHandle::CreateLocalHandle(t, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(t))};
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,

@@ -1767,9 +1762,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;

@@ -1786,9 +1779,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;

@@ -1812,9 +1803,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i]->Ref();
delete outputs[i];
}

@@ -31,6 +31,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {

@@ -455,6 +455,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
TFE_DeleteTensorHandle(on_host);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(m_data);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx);
@@ -69,18 +69,7 @@ struct TFE_Context {
};
struct TFE_TensorHandle {
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
std::unique_ptr<AbstractTensorHandleInterface> handle;
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {

@@ -92,7 +81,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
std::unique_ptr<AbstractOperationInterface> operation;
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {
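[Note: the following is an illustrative sketch, not part of the commit. It summarizes the wrap/unwrap pattern that the TFE_TensorHandle and TFE_Op structs above establish, together with the TensorHandleFromInterface helper added in tensor_handle_interface.h further down; the helper names WrapForC and UnwrapFromC are hypothetical.]

  #include <memory>
  #include "tensorflow/c/eager/c_api_internal.h"
  #include "tensorflow/c/eager/tensor_handle_interface.h"

  // Wrap: call sites now build the interface wrapper directly instead of going
  // through the removed TFE_TensorHandle::CreateLocalHandle(t, status) helper.
  TFE_TensorHandle* WrapForC(tensorflow::TensorHandle* raw_handle) {
    return new TFE_TensorHandle{
        std::make_unique<tensorflow::TensorHandleInterface>(raw_handle)};
  }

  // Unwrap: replaces the repeated
  // down_cast<TensorHandleInterface*>(h->handle.get())->Handle() chains.
  tensorflow::TensorHandle* UnwrapFromC(TFE_TensorHandle* h) {
    return tensorflow::TensorHandleFromInterface(h->handle);
  }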
@@ -184,9 +184,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.

@@ -409,13 +409,8 @@ void TensorHandleSilentCopy(bool async,
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
tensorflow/c/eager/c_api_unified_experimental.cc (new file, 261 lines)
@@ -0,0 +1,261 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"

using tensorflow::string;

// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================

typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
explicit TF_ExecutionContext() {}
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
ExecuteOperation execution_callback;
};

struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};

struct TF_AbstractOp {
string op_type;
string op_name;
};

TF_ExecutionContext* TF_NewExecutionContext() {
return new TF_ExecutionContext();
}

void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }

TF_AbstractOp* TF_NewAbstractOp() {
TF_AbstractOp* op = new TF_AbstractOp;
return op;
}

void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }

TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}

void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }

struct TF_GraphContext {
TF_Graph* graph;
// TODO(srbs): Handle captures.
};

TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
auto ctx = new TF_GraphContext;
ctx->graph = g;
return ctx;
}

void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }

struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
TF_Status* s) {
TF_GraphTensor* t = new TF_GraphTensor;
t->output = output;
t->ctx = ctx;
return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s) {
at->t = t;
}
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
string msg = absl::StrCat("Not an graph tensor handle.");
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TF_GraphTensor*>(at->t);
}

bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}

struct TF_OutputList {
std::vector<TF_AbstractTensor*> outputs;
int expected_num_outputs = -1;
};

TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}

void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
auto* tfe_op =
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
TFE_DeleteOp(tfe_op);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = TF_NewAbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}

TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}

void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
TF_Graph* g = graph_ctx->graph;
auto* tf_opdesc =
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != graph_ctx) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = TF_NewAbstractTensor();
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
t->t = output_t;
o->outputs.push_back(t);
}
}

void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context,
TF_Status* s) {
context->ctx = eager_context;
context->execution_callback = &ExecuteOperationEager;
}

void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status* s) {
context->ctx = graph_context;
context->execution_callback = &ExecuteOperationGraph;
}

void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->op_type = op_type;
}

void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->op_name = op_name;
}

void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
}
tensorflow/c/eager/c_api_unified_experimental.h (new file, 119 lines)
@@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================

// -----------------------------------------------------------------------------
// Core APIs
// -----------------------------------------------------------------------------

// A TF_ExecutionContext stores knowledge about how to execute an operation.
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;

TF_ExecutionContext* TF_NewExecutionContext();
void TF_DeleteExecutionContext(TF_ExecutionContext*);

TF_AbstractOp* TF_NewAbstractOp();
void TF_DeleteAbstractOp(TF_AbstractOp*);

TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);

// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------

// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);

// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);

// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s);
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);

// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);

// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);

// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);

// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);

// TF_ExecuteOperation will, if in eager mode, execute, if in graph mode, maybe
// capture some inputs and then add a node in the graph, and after
// execution/node creation it'll go and record things that happened in any tape
// which happens to be active.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);

#ifdef __cplusplus
} /* end extern "C" */
#endif

#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
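[Note: illustrative sketch, not part of the commit. A minimal eager-mode walk-through of the API declared above, distilled from the unit test that follows; `eager_ctx`, `t`, and `s` are assumed to exist and error checking is elided.]

  #include "tensorflow/c/eager/c_api.h"
  #include "tensorflow/c/eager/c_api_unified_experimental.h"

  void RunAddEagerly(TFE_Context* eager_ctx, TFE_TensorHandle* t, TF_Status* s) {
    TF_ExecutionContext* ctx = TF_NewExecutionContext();
    TF_ExecutionContextSetEagerContext(ctx, eager_ctx, s);

    // Wrap the eager handle in an abstract tensor and describe the op.
    TF_AbstractTensor* at = TF_NewAbstractTensor();
    TF_AbstractTensorSetEagerTensor(at, t, s);
    TF_AbstractOp* op = TF_NewAbstractOp();
    TF_AbstractOpSetOpType(op, "Add", s);

    // Eager mode requires the output count up front.
    TF_AbstractTensor* inputs[2] = {at, at};
    TF_OutputList* o = TF_NewOutputList();
    TF_OutputListSetNumOutputs(o, 1, s);
    TF_ExecuteOperation(op, 2, inputs, o, ctx, s);

    // The result wraps a TFE_TensorHandle that can be resolved as usual.
    TFE_TensorHandle* result =
        TF_AbstractTensorGetEagerTensor(TF_OutputListGet(o, 0), s);
    (void)result;  // Resolve/consume, then clean up as in the test below.

    TF_DeleteOutputList(o);
    TF_DeleteAbstractOp(op);
    TF_DeleteAbstractTensor(at);
    TF_DeleteExecutionContext(ctx);
  }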
tensorflow/c/eager/c_api_unified_experimental_test.cc (new file, 204 lines)
@@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api_unified_experimental.h"

#include <string.h>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

using tensorflow::string;

namespace tensorflow {
namespace {

TEST(UnifedCAPI, TestBasicEager) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract input tensor.
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TFE_DeleteTensorHandle(t);

// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
float* result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 4.0);

TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TFE_DeleteTensorHandle(result_t);
TF_DeleteOutputList(o);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}

TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();

// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);

TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an eager context to run the function.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);

// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);

TF_DeleteOutputList(o);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);

TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}

} // namespace
} // namespace tensorflow
@@ -121,20 +121,13 @@ std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
Tensor(DT_BOOL, TensorShape(dim_sizes)));
}
Status ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* h) {
std::unique_ptr<AbstractTensorHandleInterface>
ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) {
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
tensorflow::TensorHandle* handle = nullptr;
auto status =
return std::make_unique<TensorHandleInterface>(
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
/*op_device=*/nullptr, ctx_, &handle);
if (!status.ok()) {
return status;
}
*h = std::make_unique<TensorHandleInterface>(handle);
return status;
/*op_device=*/nullptr, ctx_));
}
std::unique_ptr<AbstractOperationInterface>

@@ -75,16 +75,14 @@ class AbstractContextInterface {
absl::Span<const int64> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual tensorflow::Status CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* handle) = 0;
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) = 0;
// Create an operation to perform op execution
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
// List attributes of available devices
virtual void ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) = 0;
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
};
// TODO(gjn): Try to move these all to EagerContext and make it implement

@@ -133,12 +131,11 @@ class ContextInterface : public AbstractContextInterface {
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
tensorflow::Status CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* h) override;
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) override;
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
void ListDevices(std::vector<DeviceAttributes>* devices) override;
// For runtime specific APIs, provide ability to get the underlying context.
EagerContext* Context() const { return ctx_; }

@@ -149,7 +146,7 @@ class ContextInterface : public AbstractContextInterface {
inline EagerContext* ContextFromInterface(
const std::unique_ptr<AbstractContextInterface>& context) {
return down_cast<tensorflow::ContextInterface*>(context.get())->Context();
return down_cast<ContextInterface*>(context.get())->Context();
}
} // namespace tensorflow
@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {

@@ -47,9 +46,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");

@@ -266,8 +266,7 @@ Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}

@@ -276,8 +275,7 @@ Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());
@@ -23,91 +23,75 @@ limitations under the License.
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual const string& Name() const = 0;
virtual const string& DeviceName() const = 0;
virtual Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
virtual Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
virtual Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
virtual Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
virtual Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) = 0;
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values, int num_values) = 0;
virtual Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value, int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
virtual Status SetUseXla(bool enable) {
return errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
virtual Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
return errors::Unimplemented("SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {

@@ -173,8 +157,8 @@ class OperationInterface : public AbstractOperationInterface {
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }
@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/casts.h"

namespace tensorflow {

// Abstract interface to a TensorHandle.
//
@ -34,24 +37,24 @@ class AbstractTensorHandleInterface {
  virtual ~AbstractTensorHandleInterface() {}

  // Check if the handle is in a valid initialized state.
  virtual bool IsValid(tensorflow::Status* status) const = 0;
  virtual bool IsValid(Status* status) const = 0;
  // Returns tensor dtype.
  virtual TF_DataType DataType() const = 0;
  // Returns number of dimensions.
  virtual int NumDims(tensorflow::Status* status) const = 0;
  virtual int NumDims(Status* status) const = 0;
  // Returns number of elements across all dimensions.
  virtual int64_t NumElements(tensorflow::Status* status) const = 0;
  virtual int64_t NumElements(Status* status) const = 0;
  // Returns size of specified dimension
  virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
  virtual int64_t Dim(int dim_index, Status* status) const = 0;

  // Returns the device which created the handle.
  virtual const char* DeviceName(tensorflow::Status* status) const = 0;
  virtual const char* DeviceName(Status* status) const = 0;
  // Returns the device where the tensor was placed.
  virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
  virtual const char* BackingDeviceName(Status* status) const = 0;
  // Returns a tensor for the handle. If tensor is remote, it will be copied.
  virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
  virtual TF_Tensor* Resolve(Status* status) = 0;
  // Returns debug information about the tensor.
  virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
  virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;

  // Return a copy of the handle.
  virtual AbstractTensorHandleInterface* Copy() = 0;
@ -65,8 +68,9 @@ class AbstractTensorHandleInterface {
  virtual void EnableImplicitMirroring() = 0;
};

namespace tensorflow {

// TODO(gjn): Try to move these all to TensorHandle and make it implement
// AbstractTensorHandleInterface. Currently, this is not so straightforward
// because of various BUILD file dependencies.
class TensorHandleInterface : public AbstractTensorHandleInterface {
 public:
  explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
@ -87,14 +91,18 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {

  void EnableImplicitMirroring() override;

  // TODO(gjn): This is not a very generic interface, but is needed for specific
  // use cases.
  // For runtime specific APIs, provide ability to get the underlying handle.
  TensorHandle* Handle() { return handle_; }

 private:
  TensorHandle* handle_;
};

inline TensorHandle* TensorHandleFromInterface(
    const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
  return down_cast<TensorHandleInterface*>(handle.get())->Handle();
}

} // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
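For readers unfamiliar with this layering: runtime-specific code reaches the concrete TensorHandle by casting the abstract interface back down, which is exactly what TensorHandleFromInterface does above. Below is a minimal, self-contained sketch of that pattern; all type names are hypothetical stand-ins (plain static_cast instead of down_cast), not the real TensorFlow classes.

#include <memory>

// Hypothetical stand-ins, for illustration only.
class FakeTensorHandle {
 public:
  int num_dims() const { return 2; }
};

class AbstractHandle {
 public:
  virtual ~AbstractHandle() = default;
};

class ConcreteHandle : public AbstractHandle {
 public:
  explicit ConcreteHandle(FakeTensorHandle* h) : handle_(h) {}
  FakeTensorHandle* Handle() { return handle_; }

 private:
  FakeTensorHandle* handle_;  // not owned
};

// Mirrors the shape of TensorHandleFromInterface: callers that know which
// runtime produced the handle unwrap the abstract interface to the concrete one.
FakeTensorHandle* Unwrap(const std::unique_ptr<AbstractHandle>& h) {
  return static_cast<ConcreteHandle*>(h.get())->Handle();
}

int main() {
  FakeTensorHandle raw;
  std::unique_ptr<AbstractHandle> abstract =
      std::make_unique<ConcreteHandle>(&raw);
  return Unwrap(abstract)->num_dims() == 2 ? 0 : 1;
}

The abstract class keeps the C API surface decoupled from the eager runtime, while the inline unwrap gives internal callers a cheap escape hatch to the concrete type.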
@ -42,6 +42,20 @@ class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): We probably need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
@ -69,6 +83,20 @@ class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  // TODO(renjieliu): we need to check for dynamic weights.
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@ -95,6 +123,19 @@ class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mirror_pad
template <>
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
 public:
  static double GetCost(mlir::Operation* op) {
    llvm::errs() << "No defined cost function for op: "
                 << op->getName().getStringRef().str();
    return 0.0;
  }

  static bool IsSupported(mlir::Operation* op) { return true; }
};

// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {
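Each new GPU specialization above follows the same shape: GetCost is a placeholder that logs the missing cost model and returns 0.0, while IsSupported gates whether the op may be delegated. A simplified, self-contained sketch of that pattern follows; the types below are illustrative stand-ins, not the real MLIR or TFLite classes.

#include <iostream>
#include <string>

// Hypothetical, simplified mirror of the specialization pattern above.
struct GpuTarget {};
struct ConcatOp { std::string name = "tfl.concatenation"; };

template <typename Op, typename Hardware>
struct CostEstimator {
  static double GetCost(const Op&) { return -1.0; }   // unknown op/target
  static bool IsSupported(const Op&) { return false; }
};

template <>
struct CostEstimator<ConcatOp, GpuTarget> {
  static double GetCost(const ConcatOp& op) {
    // Placeholder, matching the diff: no cost model is defined yet.
    std::cerr << "No defined cost function for op: " << op.name << "\n";
    return 0.0;
  }
  static bool IsSupported(const ConcatOp&) { return true; }
};

int main() {
  ConcatOp op;
  return CostEstimator<ConcatOp, GpuTarget>::IsSupported(op) ? 0 : 1;
}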
@ -579,7 +579,8 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
    NoSideEffect,
    PredOpTrait<"values and output must have same element type",
                TFL_TCresVTEtIsSameAsOp<0, 0>>,
    SameOperandsAndResultsScale
    SameOperandsAndResultsScale,
    TFL_GpuTargetOp
  ]> {
  let summary = "Concatenation operator";

@ -754,7 +755,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
    NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
    TFL_ChannelDimIndexInterface,
    AffineOpCoefficient<-1, 1>,
    TFL_SparseOp]> {
    TFL_SparseOp,
    TFL_GpuTargetOp]> {
  let summary = "Fully connected op";

  let arguments = (ins
@ -2912,7 +2914,7 @@ def TFL_CastOp : TFL_Op<"cast", [

def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
    NoSideEffect, TFL_OperandHasRank<1, 2>]> {
    NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
  let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";

  let description = [{
@ -79,27 +79,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
        pass_config.quant_specs.serialized_quant_stats));
  }

  // Note:
  // We need to fuse composite ops before LowerStaticTensorList pass.
  // The tensorflow list is not supported right now by that pass.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  // This pass marks non-exported functions as symbol visibility 'private'
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::
          CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }
  // The conversion pipeline has to follow the following orders:
  // 1) Try to convert ophint nodes if present first like ophint lstm.
  // 2) Saved model related optimization like decompose resource ops
  // 3) Convert composite functions like lstm/rnns, along with proper function
  //    inlining & dce.
  // 4) Lower static tensor list pass.

  // The ophint extractions happen before lots of other passes:
  // The assumption of ophint-extraction is each ophinted region is a black-box
@ -122,6 +107,28 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
  pass_manager->addNestedPass<mlir::FuncOp>(
      mlir::TFDevice::CreateDecomposeResourceOpsPass());

  // Note:
  // We need to fuse composite ops before LowerStaticTensorList pass.
  // The tensorflow list is not supported right now by that pass.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  // This pass marks non-exported functions as symbol visibility 'private'
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::
          CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());

  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
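The comments in this hunk encode an ordering constraint: composite-op fusion has to run before LowerStaticTensorList (which does not handle TensorFlow list ops), with visibility marking, inlining, and symbol DCE in between. A toy sketch of that ordering is shown below; it uses a made-up pass manager and string pass names purely for illustration, not the MLIR PassManager API.

#include <iostream>
#include <string>
#include <vector>

// Toy pass manager used only to illustrate the ordering described above.
struct ToyPassManager {
  std::vector<std::string> passes;
  void addPass(const std::string& name) { passes.push_back(name); }
};

struct PassConfig {
  bool emit_builtin_tflite_ops = true;
  bool lower_tensor_list_ops = true;
};

void AddConversionPasses(const PassConfig& cfg, ToyPassManager* pm) {
  // Composite-op fusion must come before lowering static tensor lists,
  // because the latter cannot see through unfused composite bodies.
  if (cfg.emit_builtin_tflite_ops) pm->addPass("prepare-composite-functions");
  pm->addPass("mark-function-visibility");
  pm->addPass("inline");
  pm->addPass("symbol-dce");
  if (cfg.lower_tensor_list_ops) pm->addPass("lower-static-tensor-list");
}

int main() {
  ToyPassManager pm;
  AddConversionPasses(PassConfig{}, &pm);
  for (const auto& p : pm.passes) std::cout << p << "\n";
  return 0;
}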
@ -197,7 +197,9 @@ Status MlirV1CompatGraphOptimizationPass::Run(
  RegisterDialects();
  mlir::MLIRContext context;
  GraphImportConfig import_config;
  import_config.upgrade_legacy = true;
  // TODO(b/150959075): Running functionalization before TPU cluster formation
  // is not semantics preserving and should be disabled for now.
  import_config.upgrade_legacy = false;
  TF_ASSIGN_OR_RETURN(
      auto module_ref,
      ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,
@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
  let summary = "Clips tensor values to a specified min and max.";

  let description = [{
Given a tensor `t`, this operation returns a tensor of the same type and
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
greater than `clip_value_max` are set to `clip_value_max`.
  }];

  let arguments = (ins
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
  );

  let results = (outs
    TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
  );

  TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}

def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
  let summary = "Converts two real numbers to a complex number.";
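The ClipByValue description above fully determines the op's elementwise semantics. A host-side reference sketch for a flat float buffer is given below; it is illustrative only, not the TensorFlow kernel.

#include <algorithm>
#include <cassert>
#include <vector>

// Reference semantics of the clip described above, on a flat float buffer.
std::vector<float> ClipByValue(std::vector<float> t, float clip_value_min,
                               float clip_value_max) {
  for (float& v : t) v = std::clamp(v, clip_value_min, clip_value_max);
  return t;
}

int main() {
  auto out = ClipByValue({-2.0f, 0.5f, 3.0f}, -1.0f, 1.0f);
  assert(out[0] == -1.0f && out[1] == 0.5f && out[2] == 1.0f);
  return 0;
}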
File diff suppressed because it is too large
@ -187,3 +187,199 @@ func @main() {
|
||||
%write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests while loop with access to the tensor array defined outside and its
|
||||
// gradient defined inside. The gradient creation should be moved outside.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> () {
|
||||
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
// CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
// CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
|
||||
%1:2 = "tf.While"(%ta#0, %size) {
|
||||
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
|
||||
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
|
||||
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: "tf.Slice"(%[[READ]],
|
||||
%read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
|
||||
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
|
||||
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
|
||||
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
|
||||
%write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
|
||||
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
|
||||
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
|
||||
}
|
||||
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
// CHECK-NEXT: return %[[CARG1]]
|
||||
return %arg1 : tensor<i32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests If op with access to the tensor array defined outside and its gradient
|
||||
// defined inside. The gradient creation should be moved outside.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> () {
|
||||
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
|
||||
%cond = "tf._SomeOp"() : () -> tensor<i1>
|
||||
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
// CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
|
||||
%1 = "tf.If"(%cond, %ta#0) {
|
||||
then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
|
||||
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: "tf.Slice"(%[[READ]],
|
||||
%read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
|
||||
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
|
||||
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %[[TARG0]]
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
|
||||
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
|
||||
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
// CHECK: return %[[EARG0]]
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests (Stateful)PartitionedCall op with access to the tensor array defined
|
||||
// outside and its gradient defined inside. The gradient creation should be
|
||||
// moved outside.
|
||||
|
||||
// CHECK-LABEL: func @main
|
||||
func @main() -> () {
|
||||
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
|
||||
%cond = "tf._SomeOp"() : () -> tensor<i1>
|
||||
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
%grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
|
||||
// CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
|
||||
// CHECK-SAME: f = @callee_tensorarray_decomposed
|
||||
%call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
|
||||
// CHECK-SAME: f = @callee_tensorarray_decomposed
|
||||
%call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
|
||||
: (tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: "tf.Slice"(%[[READ]],
|
||||
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
// CHECK-LABEL: func @callee
|
||||
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
|
||||
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
|
||||
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
%grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
|
||||
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
|
||||
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
|
||||
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
|
||||
// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
|
||||
// CHECK: return %[[CARG0]]
|
||||
|
||||
// -----
|
||||
|
||||
// Test the pass reports failure on unknown size.
|
||||
|
||||
func @main(%arg0: tensor<i32>) -> () {
|
||||
// expected-error @+1 {{unknown max element count}}
|
||||
%ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test the pass reports failure on unknown shape.
|
||||
|
||||
func @main(%arg0: tensor<i32>) -> () {
|
||||
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
|
||||
// expected-error @+1 {{unknown element shape}}
|
||||
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Tests that the pass reports error on ambiguous tensor array.
|
||||
|
||||
func @main(%arg0: tensor<i1>) -> () {
|
||||
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
|
||||
%ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
|
||||
%if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
|
||||
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
|
||||
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
|
||||
// expected-error @+1 {{unknown tensor array}}
|
||||
%read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
|
||||
return
|
||||
}
|
||||
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
return %arg0 : tensor<!tf.resource>
|
||||
}
|
||||
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
|
||||
return %arg1 : tensor<!tf.resource>
|
||||
}
|
||||
|
@ -40,6 +40,20 @@ class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
  void runOnFunction() override;
};

// Returns whether the two values are guaranteed to be broadcastable to the
// same shape, this broadcasts size 1 tensors up to any rank.
// TODO(jpienaar): Move this to more general location.
static bool AreBroadcastCompatible(Value x, Value y) {
  auto x_ranked = x.getType().dyn_cast<RankedTensorType>();
  auto y_ranked = y.getType().dyn_cast<RankedTensorType>();
  if (!x_ranked || !y_ranked) {
    return true;
  }
  SmallVector<int64_t, 4> resultShape;
  return OpTrait::util::getBroadcastedShape(x_ranked.getShape(),
                                            y_ranked.getShape(), resultShape);
}

#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"

/// Performs the lowering to XLA dialect.
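AreBroadcastCompatible above delegates the actual check to OpTrait::util::getBroadcastedShape, but the underlying rule can be stated without MLIR types: align shapes from the trailing dimension and require each pair of dimensions to match or be 1. A standalone sketch under that assumption, with shapes as plain vectors:

#include <cassert>
#include <cstdint>
#include <vector>

// Standalone version of the broadcast-compatibility rule referenced above.
// Missing leading dimensions are implicitly broadcastable.
bool ShapesBroadcastCompatible(const std::vector<int64_t>& x,
                               const std::vector<int64_t>& y) {
  auto xi = x.rbegin();
  auto yi = y.rbegin();
  for (; xi != x.rend() && yi != y.rend(); ++xi, ++yi) {
    if (*xi != *yi && *xi != 1 && *yi != 1) return false;
  }
  return true;
}

int main() {
  assert(ShapesBroadcastCompatible({5, 1, 3}, {4, 3}));
  assert(!ShapesBroadcastCompatible({5, 2, 3}, {4, 3}));
  return 0;
}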
@ -20,14 +20,16 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
|
||||
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
|
||||
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
|
||||
|
||||
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
|
||||
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Binary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class DirectBinaryPat<Op FromOp, Op ToOp>
|
||||
: Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>;
|
||||
// Check that two values can be broadcasted together
|
||||
// TODO(jpienaar): Move somewhere more general
|
||||
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
|
||||
"types must be broadcastable">;
|
||||
|
||||
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
|
||||
[HLO_DivOp, TF_DivOp],
|
||||
@ -37,24 +39,41 @@ foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
|
||||
[HLO_MulOp, TF_MulOp],
|
||||
[HLO_PowOp, TF_PowOp],
|
||||
[HLO_DivOp, TF_RealDivOp],
|
||||
[HLO_SubOp, TF_SubOp]] in
|
||||
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
|
||||
[HLO_SubOp, TF_SubOp],
|
||||
[HLO_Atan2Op, TF_Atan2Op],
|
||||
[HLO_RemOp, TF_ModOp]] in
|
||||
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def LowerRightShiftSigned :
|
||||
Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(SignedIntTensor $r)]>;
|
||||
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
|
||||
[HLO_OrOp, TF_BitwiseOrOp],
|
||||
[HLO_XorOp, TF_BitwiseXorOp]] in
|
||||
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>;
|
||||
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
|
||||
[HLO_OrOp, TF_LogicalOrOp]] in
|
||||
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Unary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
foreach Mapping = [
|
||||
[HLO_AbsOp, TF_AbsOp],
|
||||
foreach Mapping = [[HLO_AbsOp, TF_AbsOp],
|
||||
[HLO_BitcastConvertOp, TF_BitcastOp],
|
||||
[HLO_CeilOp, TF_CeilOp],
|
||||
[HLO_CosOp, TF_CosOp],
|
||||
[HLO_ExpOp, TF_ExpOp],
|
||||
[HLO_Expm1Op, TF_Expm1Op],
|
||||
[HLO_FloorOp, TF_FloorOp],
|
||||
[HLO_ImagOp, TF_ImagOp],
|
||||
[HLO_IsFiniteOp, TF_IsFiniteOp],
|
||||
@ -65,8 +84,46 @@ foreach Mapping = [
|
||||
[HLO_RealOp, TF_RealOp],
|
||||
[HLO_RsqrtOp, TF_RsqrtOp],
|
||||
[HLO_SinOp, TF_SinOp],
|
||||
[HLO_SignOp, TF_SignOp],
|
||||
[HLO_SqrtOp, TF_SqrtOp],
|
||||
[HLO_TanhOp, TF_TanhOp],
|
||||
] in {
|
||||
def : Pat<(Mapping[0] $input), (Mapping[1] $input)>;
|
||||
}
|
||||
[HLO_TanhOp, TF_TanhOp]] in
|
||||
def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>;
|
||||
|
||||
def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>;
|
||||
|
||||
def : Pat<(HLO_BroadcastOp $arg, $shape),
|
||||
(TF_BroadcastToOp $arg, (TF_ConstOp $shape))>;
|
||||
def : Pat<(HLO_TransposeOp $arg, $permutation),
|
||||
(TF_TransposeOp $arg, (TF_ConstOp $permutation))>;
|
||||
def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ternary op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(HLO_ClampOp $min, $arg, $max),
|
||||
(TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>;
|
||||
def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Variadic op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def : Pat<(HLO_ConcatenateOp $inputs, $dim),
|
||||
(TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compare op patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
|
||||
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
||||
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
|
||||
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
|
||||
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
|
||||
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
|
||||
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
|
||||
[(AreBroadcastCompatible $l, $r)]>;
|
||||
|
@ -19,7 +19,8 @@ limitations under the License.
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/ADT/StringMap.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/Casting.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
@ -55,6 +56,8 @@ namespace {
|
||||
|
||||
namespace cutil = TF::collection_ops_util;
|
||||
|
||||
using std::string;
|
||||
|
||||
// A pass that converts tensor array operations to tensor operations and
|
||||
// read/assign ops on local variables. A later resource lifting pass can further
|
||||
// remove the local variables.
|
||||
@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
|
||||
return split.emitOpError("unknown or invalid split tensor shape");
|
||||
}
|
||||
int64_t length = buffer_type.getDimSize(0) / *count;
|
||||
for (auto len : lengths_const.value().getValues<APInt>()) {
|
||||
for (const auto& len : lengths_const.value().getValues<APInt>()) {
|
||||
if (length == len.getSExtValue()) continue;
|
||||
return split.emitOpError("different split lengths are not supported");
|
||||
}
|
||||
@ -145,7 +148,7 @@ struct TensorArrayStats {
|
||||
// this is a gradient.
|
||||
bool accumulate_on_write;
|
||||
// Maps from a gradient source string to the local variable to the gradient.
|
||||
llvm::SmallDenseMap<llvm::StringRef, Value> grads;
|
||||
llvm::StringMap<Value> grads;
|
||||
};
|
||||
|
||||
LogicalResult HandleTensorArrayV3Op(
|
||||
@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
|
||||
cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
|
||||
/*keep_slice_shape=*/true);
|
||||
// Add a size-1 leading dimension to elem.
|
||||
for (auto dim : buffer.getType().cast<RankedTensorType>().getShape())
|
||||
LOG(ERROR) << " buffer : " << dim;
|
||||
auto slice_type = original_elem.getType().cast<RankedTensorType>();
|
||||
for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim;
|
||||
elem = builder.create<TF::ReshapeOp>(
|
||||
write.getLoc(), ArrayRef<Type>{slice_type},
|
||||
ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
|
||||
@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
|
||||
Operation* op, Value* var) {
|
||||
OpBuilder builder(op);
|
||||
*var = builder.create<TF::MlirLocalVarOp>(
|
||||
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
Value buffer;
|
||||
auto buffer_type = getElementTypeOrSelf(local_var_type)
|
||||
.cast<TF::ResourceType>()
|
||||
.getSubtypes()[0]
|
||||
.cast<RankedTensorType>();
|
||||
if (failed(cutil::CreateInitBufferValue(
|
||||
buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
|
||||
buffer_type.getElementType(), builder, &buffer))) {
|
||||
return failure();
|
||||
}
|
||||
cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleTensorArrayGradV3Op(
|
||||
TF::TensorArrayGradV3Op grad,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
|
||||
@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op(
|
||||
Value grad_var;
|
||||
auto sit = stats->find(local_var);
|
||||
if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
|
||||
auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value());
|
||||
auto emplace_res =
|
||||
sit->getSecond().grads.try_emplace(grad.source().str(), Value());
|
||||
if (!emplace_res.second) {
|
||||
// If the source has been assigned a grad, use it.
|
||||
grad_var = emplace_res.first->getSecond();
|
||||
grad_var = emplace_res.first->second;
|
||||
} else {
|
||||
grad_var = builder.create<TF::MlirLocalVarOp>(
|
||||
grad.getLoc(), ArrayRef<Type>{local_var.getType()}, ArrayRef<Value>{},
|
||||
ArrayRef<NamedAttribute>{});
|
||||
Value buffer;
|
||||
auto buffer_type = getElementTypeOrSelf(local_var.getType())
|
||||
.cast<TF::ResourceType>()
|
||||
.getSubtypes()[0]
|
||||
.cast<RankedTensorType>();
|
||||
if (failed(cutil::CreateInitBufferValue(
|
||||
buffer_type.getShape().drop_front(), buffer_type.getDimSize(0),
|
||||
grad, buffer_type.getElementType(), builder, &buffer))) {
|
||||
if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
|
||||
&grad_var))) {
|
||||
return failure();
|
||||
}
|
||||
cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc());
|
||||
emplace_res.first->getSecond() = grad_var;
|
||||
emplace_res.first->second = grad_var;
|
||||
// Write to a grad accumulates with previous writes.
|
||||
(*stats)[grad_var].accumulate_on_write = true;
|
||||
}
|
||||
@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op(
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) {
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
|
||||
// Updates func's type according to its current arguments and return values.
|
||||
void UpdateFuncType(FuncOp func) {
|
||||
llvm::SmallVector<Type, 8> arg_types;
|
||||
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
|
||||
func.setType(FunctionType::get(
|
||||
arg_types,
|
||||
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
|
||||
func.getContext()));
|
||||
}
|
||||
|
||||
// Finds the accessed gradient sources for each tensor array argument.
|
||||
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
|
||||
ArrayRef<FuncOp> funcs, ModuleOp module) {
|
||||
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
|
||||
llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
|
||||
auto insert = [&](Value v, const string& source) {
|
||||
auto arg = v.cast<BlockArgument>();
|
||||
if (!arg) return;
|
||||
auto insert_res = result_sets[arg.getArgNumber()].insert(source);
|
||||
if (!insert_res.second) return;
|
||||
result[arg.getArgNumber()].push_back(source);
|
||||
};
|
||||
for (FuncOp func : funcs) {
|
||||
for (auto& op : func.front().getOperations()) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
continue;
|
||||
}
|
||||
if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
|
||||
insert(grad.handle(), grad.source().str());
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
auto body = module.lookupSymbol<FuncOp>(while_op.body());
|
||||
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
|
||||
for (const auto& entry : AccessedGradients({body, cond}, module)) {
|
||||
for (const string& source : entry.getSecond()) {
|
||||
insert(while_op.getOperand(entry.getFirst()), source);
|
||||
}
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
|
||||
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
|
||||
for (const auto& entry :
|
||||
AccessedGradients({then_branch, else_branch}, module)) {
|
||||
for (const string& source : entry.getSecond()) {
|
||||
insert(if_op.getOperand(entry.getFirst() + 1), source);
|
||||
}
|
||||
}
|
||||
} else if (auto pc = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
if (!pc.f().isa<FlatSymbolRefAttr>()) continue;
|
||||
auto callee = module.lookupSymbol<FuncOp>(pc.f().getRootReference());
|
||||
for (const auto& entry : AccessedGradients({callee}, module)) {
|
||||
for (const string& source : entry.getSecond()) {
|
||||
insert(pc.getOperand(entry.getFirst()), source);
|
||||
}
|
||||
}
|
||||
} else if (auto spc =
|
||||
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
|
||||
auto callee = module.lookupSymbol<FuncOp>(spc.f());
|
||||
for (const auto& entry : AccessedGradients({callee}, module)) {
|
||||
for (const string& source : entry.getSecond()) {
|
||||
insert(spc.getOperand(entry.getFirst()), source);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Contains cached information for decomposed callee functions for (stateful)
|
||||
// partitioned call ops.
|
||||
struct PartitionedCallTensorArrayOpsInfo {
|
||||
bool signature_change;
|
||||
FuncOp decomposed_callee;
|
||||
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
|
||||
arg_grads;
|
||||
llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
|
||||
};
|
||||
|
||||
// Updates a called function's input signature by adjusting resource types, and
|
||||
// adding required gradient arguments.
|
||||
void ChangeFunctionInputSignature(
|
||||
FuncOp func,
|
||||
const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
|
||||
llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
|
||||
llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
|
||||
int64_t original_args = func.getNumArguments();
|
||||
for (int64_t argnum = 0; argnum < original_args; ++argnum) {
|
||||
auto arg = func.getArgument(argnum);
|
||||
Type t = ta_arg_buffer_type(argnum);
|
||||
if (!t) continue;
|
||||
arg.setType(t);
|
||||
auto grad_it = grads.find(argnum);
|
||||
if (grad_it == grads.end()) continue;
|
||||
llvm::StringMap<Value> grads_map;
|
||||
for (const string& source : grad_it->getSecond()) {
|
||||
auto g = func.front().addArgument(t);
|
||||
(*stats)[g].accumulate_on_write = true;
|
||||
grads_map[source] = g;
|
||||
}
|
||||
auto& stat = (*stats)[arg];
|
||||
stat.accumulate_on_write = ta_accumulate_on_write(argnum);
|
||||
stat.grads = std::move(grads_map);
|
||||
}
|
||||
UpdateFuncType(func);
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTensorArrayOps(
|
||||
Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*);
|
||||
|
||||
LogicalResult HandleWhileOp(
|
||||
TF::WhileOp while_op, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto body = module.lookupSymbol<FuncOp>(while_op.body());
|
||||
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
|
||||
auto grads = AccessedGradients({body, cond}, module);
|
||||
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
|
||||
auto it = stats->find(while_op.getOperand(index));
|
||||
if (it == stats->end()) return nullptr;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
auto ta_accumulate_on_write = [&](int64_t index) {
|
||||
auto it = stats->find(while_op.getOperand(index));
|
||||
if (it == stats->end()) return false;
|
||||
return it->getSecond().accumulate_on_write;
|
||||
};
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
|
||||
ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
|
||||
ta_accumulate_on_write, &body_stats);
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
|
||||
ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
|
||||
ta_accumulate_on_write, &cond_stats);
|
||||
if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
|
||||
decomposed_partitioned_call_callees)) ||
|
||||
failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
if (body_stats.empty() && cond_stats.empty()) return success();
|
||||
auto old_body_ret = body.front().getTerminator();
|
||||
auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
|
||||
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
|
||||
if (!ta_arg_buffer_type(i)) continue;
|
||||
auto retval = old_body_ret->getOperand(i);
|
||||
auto arg = retval.dyn_cast<BlockArgument>();
|
||||
if (!arg) {
|
||||
return while_op.emitOpError(
|
||||
"output tensor array does not alias input in a while loop");
|
||||
}
|
||||
for (const string& source : grads[i]) {
|
||||
new_retvals.push_back(body_stats[arg].grads[source]);
|
||||
}
|
||||
}
|
||||
OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
|
||||
old_body_ret->erase();
|
||||
UpdateFuncType(body);
|
||||
// Recreate the while op.
|
||||
auto operands = llvm::to_vector<8>(while_op.getOperands());
|
||||
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
|
||||
auto grad_it = grads.find(i);
|
||||
auto& stat = (*stats)[operands[i]];
|
||||
if (grad_it == grads.end()) continue;
|
||||
for (const string& source : grad_it->getSecond()) {
|
||||
auto it = stat.grads.find(source);
|
||||
if (it != stat.grads.end()) {
|
||||
operands.push_back(it->second);
|
||||
} else {
|
||||
Value grad_var;
|
||||
if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
|
||||
while_op, &grad_var))) {
|
||||
return failure();
|
||||
}
|
||||
stat.grads[source] = grad_var;
|
||||
operands.push_back(grad_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
OpBuilder builder(while_op);
|
||||
auto new_while =
|
||||
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
|
||||
operands, while_op.getAttrs());
|
||||
// Clear the output shapes as it is not needed for XLA lowering.
|
||||
new_while.setAttr("output_shapes", builder.getArrayAttr({}));
|
||||
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
|
||||
if (ta_arg_buffer_type(i)) {
|
||||
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
|
||||
} else {
|
||||
while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
|
||||
}
|
||||
}
|
||||
while_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult HandleIfOp(
|
||||
TF::IfOp if_op, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
|
||||
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
|
||||
auto grads = AccessedGradients({then_branch, else_branch}, module);
|
||||
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
|
||||
auto it = stats->find(if_op.getOperand(index + 1));
|
||||
if (it == stats->end()) return nullptr;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
auto ta_accumulate_on_write = [&](int64_t index) {
|
||||
auto it = stats->find(if_op.getOperand(index + 1));
|
||||
if (it == stats->end()) return false;
|
||||
return it->getSecond().accumulate_on_write;
|
||||
};
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
|
||||
ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
|
||||
ta_accumulate_on_write, &then_stats);
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
|
||||
ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
|
||||
ta_accumulate_on_write, &else_stats);
|
||||
if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
|
||||
decomposed_partitioned_call_callees)) ||
|
||||
failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
if (then_stats.empty() && else_stats.empty()) return success();
|
||||
// Recreate the if op.
|
||||
auto operands = llvm::to_vector<8>(if_op.getOperands());
|
||||
for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
|
||||
auto grad_it = grads.find(i);
|
||||
auto& stat = (*stats)[operands[i + 1]];
|
||||
if (grad_it == grads.end()) continue;
|
||||
for (const string& source : grad_it->getSecond()) {
|
||||
auto it = stat.grads.find(source);
|
||||
if (it != stat.grads.end()) {
|
||||
operands.push_back(it->second);
|
||||
} else {
|
||||
Value grad_var;
|
||||
if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
|
||||
if_op, &grad_var))) {
|
||||
return failure();
|
||||
}
|
||||
stat.grads[source] = grad_var;
|
||||
operands.push_back(grad_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
OpBuilder builder(if_op);
|
||||
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
|
||||
then_branch.getType().getResults(),
|
||||
operands, if_op.getAttrs());
|
||||
// Clear the output shapes as it is not needed for XLA lowering.
|
||||
new_if.setAttr("output_shapes", builder.getArrayAttr({}));
|
||||
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
|
||||
auto retval = f.front().getTerminator()->getOperand(ret_ind);
|
||||
auto arg = retval.dyn_cast<BlockArgument>();
|
||||
if (!arg) return -1;
|
||||
return arg.getArgNumber();
|
||||
};
|
||||
for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
|
||||
if (!getElementTypeOrSelf(if_op.getResult(i).getType())
|
||||
.isa<TF::ResourceType>()) {
|
||||
if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
|
||||
continue;
|
||||
}
|
||||
int64_t then_forward_input = ret_forwards_input(then_branch, i);
|
||||
int64_t else_foward_input = ret_forwards_input(else_branch, i);
|
||||
if (then_forward_input != else_foward_input || then_forward_input < 0) {
|
||||
return if_op.emitOpError(
|
||||
"branches do not forward the same input resource");
|
||||
}
|
||||
if_op.getResult(i).replaceAllUsesWith(
|
||||
if_op.getOperand(then_forward_input + 1));
|
||||
}
|
||||
if_op.erase();
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename CallOp>
|
||||
LogicalResult HandlePartitionedCallOp(
|
||||
CallOp call, FuncOp callee, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
|
||||
callee, PartitionedCallTensorArrayOpsInfo());
|
||||
auto& info = emplace_res.first->getSecond();
|
||||
// Recreates the call op with info.
|
||||
auto recreate_caller = [&]() -> LogicalResult {
|
||||
auto new_operands = llvm::to_vector<8>(call.getOperands());
|
||||
for (const auto& entry : info.arg_grads) {
|
||||
auto it = stats->find(call.getOperand(entry.first));
|
||||
if (it == stats->end()) return call.emitOpError("unknown tensor array");
|
||||
for (const string& source : entry.second) {
|
||||
auto grad_it = it->getSecond().grads.find(source);
|
||||
if (grad_it != it->getSecond().grads.end()) {
|
||||
new_operands.push_back(grad_it->second);
|
||||
} else {
|
||||
Value grad_var;
|
||||
if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
|
||||
call, &grad_var))) {
|
||||
return failure();
|
||||
}
|
||||
it->getSecond().grads[source] = grad_var;
|
||||
new_operands.push_back(grad_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
OpBuilder builder(call);
|
||||
auto new_call = builder.create<CallOp>(
|
||||
call.getLoc(), info.decomposed_callee.getType().getResults(),
|
||||
new_operands, call.getAttrs());
|
||||
new_call.setAttr(
|
||||
"f", builder.getSymbolRefAttr(
|
||||
const_cast<FuncOp&>(info.decomposed_callee).getName()));
|
||||
for (const auto& entry : info.ret_forward_input) {
|
||||
call.getResult(entry.first)
|
||||
.replaceAllUsesWith(call.getOperand(entry.second));
|
||||
}
|
||||
call.replaceAllUsesWith(new_call);
|
||||
call.erase();
|
||||
return success();
|
||||
};
|
||||
if (!emplace_res.second) {
|
||||
// This callee was handled before.
|
||||
if (!info.signature_change) return success();
|
||||
return recreate_caller();
|
||||
}
|
||||
// Rewrite the callee on a cloned function.
|
||||
info.signature_change = false;
|
||||
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
|
||||
auto it = stats->find(call.getOperand(index));
|
||||
if (it == stats->end()) return nullptr;
|
||||
info.signature_change = true;
|
||||
return it->getFirst().getType();
|
||||
};
|
||||
auto ta_accumulate_on_write = [&](int64_t index) {
|
||||
auto it = stats->find(call.getOperand(index));
|
||||
if (it == stats->end()) return false;
|
||||
return it->getSecond().accumulate_on_write;
|
||||
};
|
||||
auto callee_clone = callee.clone();
|
||||
auto grads = AccessedGradients({callee_clone}, module);
|
||||
for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) {
|
||||
auto it = grads.find(i);
|
||||
if (it == grads.end()) continue;
|
||||
info.arg_grads.emplace_back(i, it->getSecond());
|
||||
}
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
|
||||
ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type,
|
||||
ta_accumulate_on_write, &callee_stats);
|
||||
if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module,
|
||||
&callee_stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
for (int64_t i = 0; i < call.getNumResults(); ++i) {
|
||||
auto ret = callee_clone.front().getTerminator()->getOperand(i);
|
||||
if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
|
||||
auto arg = ret.dyn_cast<BlockArgument>();
|
||||
if (!arg) continue;
|
||||
info.ret_forward_input.emplace_back(i, arg.getArgNumber());
|
||||
}
|
||||
|
||||
if (!info.signature_change) {
|
||||
// Signature is not modified. We do not need to keep two copies.
|
||||
info.signature_change = false;
|
||||
auto name = callee.getName();
|
||||
callee.erase();
|
||||
callee_clone.setName(name);
|
||||
SymbolTable(module).insert(callee_clone);
|
||||
} else {
|
||||
info.decomposed_callee = callee_clone;
|
||||
// Add the clone with a new name.
|
||||
auto name =
|
||||
llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed")
|
||||
.str();
|
||||
callee_clone.setName(name);
|
||||
SymbolTable(module).insert(callee_clone);
|
||||
}
|
||||
if (info.signature_change) return recreate_caller();
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult DecomposeTensorArrayOps(
|
||||
Block* block, ModuleOp module,
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
|
||||
decomposed_partitioned_call_callees) {
|
||||
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
|
||||
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
|
||||
op.replaceAllUsesWith(op.getOperands());
|
||||
op.erase();
|
||||
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayV3Op(ta, module, &stats))) {
|
||||
if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure();
|
||||
if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
|
||||
} else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure();
|
||||
if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
|
||||
} else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure();
|
||||
if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
|
||||
} else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
|
||||
if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure();
|
||||
if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
|
||||
} else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
|
||||
if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure();
|
||||
if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
|
||||
} else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure();
|
||||
if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
|
||||
} else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure();
|
||||
if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
|
||||
} else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
|
||||
if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) {
|
||||
if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
|
||||
close.erase();
|
||||
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
|
||||
if (failed(HandleWhileOp(while_op, module, stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
|
||||
if (failed(HandleIfOp(if_op, module, stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
|
||||
if (!pcall.f().isa<FlatSymbolRefAttr>()) {
|
||||
return pcall.emitOpError(
|
||||
"TensorArray decomposition does not support call with nested "
|
||||
"references.");
|
||||
}
|
||||
if (failed(HandlePartitionedCallOp(
|
||||
pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
|
||||
module, stats, decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
} else if (auto spcall =
|
||||
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
|
||||
if (failed(HandlePartitionedCallOp(
|
||||
spcall, module.lookupSymbol<FuncOp>(spcall.f()), module, stats,
|
||||
decomposed_partitioned_call_callees))) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() {
|
||||
auto module = getModule();
|
||||
auto main = module.lookupSymbol<FuncOp>("main");
|
||||
if (!main) return;
|
||||
if (failed(DecomposeTensorArrayOps(&main.front(), module))) {
|
||||
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
|
||||
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>
|
||||
decomposed_partitioned_call_callees;
|
||||
if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
|
||||
&decomposed_partitioned_call_callees))) {
|
||||
signalPassFailure();
|
||||
}
|
||||
}
|
||||
|
@ -261,6 +261,7 @@ Status ConvertMLIRToXlaComputation(
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
|
||||
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
|
||||
tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
|
||||
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
|
||||
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
|
||||
// LegalizeTFControlFlow encapsulates arguments for control flow operations
|
||||
|
@ -732,6 +732,7 @@ genrule(
|
||||
outs = ["operator_writers.inc"],
|
||||
cmd = ("$(location :operator_writer_gen) " +
|
||||
"-I external/llvm-project/mlir/include " +
|
||||
"-I external/org_tensorflow " +
|
||||
"$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " +
|
||||
" -o $@"),
|
||||
tools = [":operator_writer_gen"],
|
||||
|
@ -667,10 +667,11 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) {
|
||||
|
||||
PyLocalExecutable::PyLocalExecutable(
|
||||
std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client)
|
||||
bool tuple_arguments, DeviceAssignment device_assignment,
|
||||
PyLocalClient* client)
|
||||
: client_(client),
|
||||
device_assignment_(
|
||||
std::make_shared<DeviceAssignment>(device_assignment)) {
|
||||
device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
|
||||
tuple_arguments_(tuple_arguments) {
|
||||
executables_.reserve(executables.size());
|
||||
for (auto& executable : executables) {
|
||||
executables_.emplace_back(std::move(executable));
|
||||
@ -727,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
|
||||
|
||||
std::unique_ptr<PyLocalBuffer> tuple_buffer;
|
||||
std::vector<PyLocalBuffer*> tupled_arguments;
|
||||
if (options.tuple_arguments) {
|
||||
if (options.tuple_arguments || tuple_arguments_) {
|
||||
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
|
||||
argument_handles, client_, device));
|
||||
tupled_arguments = {tuple_buffer.get()};
|
||||
@ -1037,7 +1038,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
|
||||
build_options));
|
||||
|
||||
return absl::make_unique<PyLocalExecutable>(
|
||||
std::move(local_executables), build_options.device_assignment(), client);
|
||||
std::move(local_executables), options.tuple_arguments,
|
||||
build_options.device_assignment(), client);
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -316,6 +316,10 @@ struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;

// If true, the arguments to the computation will be wrapped in a tuple and
// passed as a single parameter.
bool tuple_arguments = false;

// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
};
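A short caller-side sketch of the new flag (editorial addition; only the
CompileOptions field is taken from this diff, the surrounding call is assumed):

  CompileOptions options;
  options.tuple_arguments = true;  // compile to expect one tupled parameter
  // ... pass `options` as the final argument to PyLocalExecutable::Compile().
  // ExecuteHelper() then tuples the argument handles even when the per-call
  // options.tuple_arguments is left false (see the `||` added earlier in this
  // diff).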
|
||||
@ -340,7 +344,8 @@ class PyLocalExecutable {
|
||||
CompileOptions options);
|
||||
|
||||
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
DeviceAssignment device_assignment, PyLocalClient* client);
|
||||
bool tuple_arguments, DeviceAssignment device_assignment,
|
||||
PyLocalClient* client);
|
||||
|
||||
PyLocalClient* client() const { return client_; }
|
||||
|
||||
@ -404,6 +409,10 @@ class PyLocalExecutable {
|
||||
std::vector<std::shared_ptr<LocalExecutable>> executables_;
|
||||
std::shared_ptr<DeviceAssignment> device_assignment_;
|
||||
|
||||
// True if the executables were compiled expecting arguments in a single
|
||||
// tuple.
|
||||
const bool tuple_arguments_;
|
||||
|
||||
// The replica and partition indices of device_assignment_ to be run by this
|
||||
// client. On single-host platforms without partitioning, this is all replicas
|
||||
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case
|
||||
|
@ -44,6 +44,12 @@ class CpuTransferManager : public GenericTransferManager {
|
||||
const Shape& literal_shape,
|
||||
MutableBorrowingLiteral literal) override;
|
||||
|
||||
bool CanShapedBufferBeAccessedNow(
|
||||
se::StreamExecutor* executor,
|
||||
const ShapedBuffer& device_buffer) const override {
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
|
||||
const void* source);
|
||||
|
@ -518,7 +518,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
|
||||
<< "Invalid LLVM IR before optimizations:\n"
|
||||
<< err_stream.str()
|
||||
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
|
||||
"Rerun with --xla_dump_to to get the IR. ";
|
||||
"Rerun with --xla_dump_to to get the IR and looks for files with "
|
||||
"name containing: *"
|
||||
<< FilenameFor(*module, "", "") << "*";
|
||||
}
|
||||
|
||||
GpuVersion gpu_version = GetGpuVersion(stream_exec);
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <stack>
|
||||
#include <vector>
|
||||
@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
std::vector<int64> relative_order;
for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
if (shape.dimensions(dim) > 1) {
relative_order.push_back(dim);
}
}
// Now normalize the dimensions to values between 0 and true rank - 1.
std::vector<int64> sorted_dims = relative_order;
std::sort(sorted_dims.begin(), sorted_dims.end());
for (int64& dim : relative_order) {
int64 sorted_index = std::distance(
sorted_dims.begin(),
std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
dim = sorted_index;
}
return relative_order;
}
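// Editorial worked example (added for clarity, not part of the original
// change): for a parameter of shape f16[128,1,32,32]{1,3,2,0}, the
// minor-to-major dimensions with size > 1 are {3, 2, 0}; normalizing them to a
// dense 0..2 range gives {2, 1, 0}, the relative order of the three
// non-trivial dimensions.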
|
||||
|
||||
} // namespace
|
||||
|
||||
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
|
||||
@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
|
||||
std::vector<HloInstruction*> params;
|
||||
AppendParams(producer, ¶ms);
|
||||
AppendParams(reduce, ¶ms);
|
||||
int64 max_rank = -1;
|
||||
const Layout* max_rank_layout;
|
||||
int64 max_true_rank = -1;
|
||||
std::vector<int64> max_rank_order;
|
||||
for (HloInstruction* param : params) {
|
||||
if (param->shape().IsArray() && param->shape().rank() > max_rank) {
|
||||
max_rank = param->shape().rank();
|
||||
max_rank_layout = ¶m->shape().layout();
|
||||
if (param->shape().IsArray() &&
|
||||
ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
|
||||
max_true_rank = ShapeUtil::TrueRank(param->shape());
|
||||
max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
|
||||
}
|
||||
}
|
||||
return absl::c_all_of(params, [&](HloInstruction* param) {
|
||||
return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
|
||||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
|
||||
return !param->shape().IsArray() ||
|
||||
ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
|
||||
ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
|
||||
max_rank_order;
|
||||
});
|
||||
}
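// Editorial note (illustration only): with the relaxed check, layouts
// {1,3,2,0} and {3,2,1,0} over shape [128,1,32,32] both normalize to the
// relative order {2, 1, 0} once the size-1 dimension is ignored, so such
// parameter pairs are now reported as reduce-input-fusion friendly; the new
// MixedLayoutProducerWithTrivialDim test below exercises exactly this case.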
|
||||
|
||||
|
@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest,
|
||||
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest,
|
||||
LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
mixed_input_layouts_computation {
|
||||
p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
|
||||
copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1)
|
||||
c0 = f16[] constant(0)
|
||||
broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={}
|
||||
greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
|
||||
ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
|
||||
}
|
||||
fused_reduce {
|
||||
p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2)
|
||||
c0.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
|
||||
}
|
||||
ENTRY entry {
|
||||
p0 = f16[128,1,32,32]{1,3,2,0} parameter(0)
|
||||
p1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
|
||||
loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
|
||||
reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
|
||||
ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
|
||||
})"))
|
||||
.ValueOrDie();
|
||||
SCOPED_TRACE(module->ToString());
|
||||
const HloInstruction* reduce_fusion =
|
||||
module->entry_computation()->root_instruction()->operand(0);
|
||||
ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
|
||||
HloOpcode::kReduce);
|
||||
const HloInstruction* loop_fusion =
|
||||
module->entry_computation()->root_instruction()->operand(1);
|
||||
ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
|
||||
EXPECT_TRUE(
|
||||
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
fused_reduce {
|
||||
@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest,
|
||||
}
|
||||
|
||||
TEST_F(GpuFusibleTest,
|
||||
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
|
||||
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) {
|
||||
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
|
||||
broadcasting_computation {
|
||||
p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
|
||||
p1.1 = f32[128]{0} parameter(1)
|
||||
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
|
||||
p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
|
||||
reshape = f32[128]{0} reshape(p1.1)
|
||||
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0}
|
||||
ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
|
||||
}
|
||||
ENTRY entry {
|
||||
p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
|
||||
p1 = f32[128]{0} parameter(1)
|
||||
p1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
|
||||
loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
|
||||
c0.2 = f32[] constant(0)
|
||||
ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add
|
||||
|
@ -27,6 +27,7 @@ cc_library(
|
||||
srcs = ["conv_emitter.cc"],
|
||||
hdrs = ["conv_emitter.h"],
|
||||
deps = [
|
||||
":conv_emitter_transforms",
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
|
||||
@ -39,6 +40,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "conv_emitter_transforms",
|
||||
srcs = ["conv_emitter_transforms.cc"],
|
||||
hdrs = ["conv_emitter_transforms.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@llvm-project//llvm:support",
|
||||
"@llvm-project//mlir:Affine",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:StandardOps",
|
||||
"@llvm-project//mlir:TransformUtils",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "conv_emitter_test",
|
||||
srcs = ["conv_emitter_test.cc"],
|
||||
|
@ -38,6 +38,7 @@ limitations under the License.
|
||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -109,48 +110,6 @@ ShapeInfo GetShapeInfo(
|
||||
return shape_info;
|
||||
}
|
||||
|
||||
bool IsSimpleLoop(mlir::AffineForOp loop) {
|
||||
return loop.getLowerBoundMap().isSingleConstant() &&
|
||||
loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
|
||||
loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
|
||||
std::next(loop.region().begin()) == loop.region().end();
|
||||
}
|
||||
|
||||
struct BoundAffineMap {
|
||||
mlir::AffineMap affine_map;
|
||||
std::vector<mlir::Value> operands;
|
||||
};
|
||||
|
||||
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
|
||||
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
|
||||
return {load.getAffineMap(),
|
||||
std::vector<mlir::Value>(load.getMapOperands().begin(),
|
||||
load.getMapOperands().end())};
|
||||
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
|
||||
return {store.getAffineMap(),
|
||||
std::vector<mlir::Value>(store.getMapOperands().begin(),
|
||||
store.getMapOperands().end())};
|
||||
} else {
|
||||
CHECK(false);
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
|
||||
BoundAffineMap new_affine,
|
||||
mlir::OpBuilder builder) {
|
||||
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
|
||||
return builder.create<mlir::AffineLoadOp>(
|
||||
builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
|
||||
new_affine.operands);
|
||||
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
|
||||
return builder.create<mlir::AffineStoreOp>(
|
||||
builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
|
||||
new_affine.affine_map, new_affine.operands);
|
||||
} else {
|
||||
CHECK(false);
|
||||
}
|
||||
}
|
||||
|
||||
void SetMemRef(mlir::Operation* op, mlir::Value memref) {
|
||||
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
|
||||
load.setMemRef(memref);
|
||||
@ -161,127 +120,6 @@ void SetMemRef(mlir::Operation* op, mlir::Value memref) {
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
|
||||
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
|
||||
std::vector<mlir::AffineForOp> loops;
|
||||
loops.reserve(upper_bounds.size());
|
||||
for (int64_t dim : upper_bounds) {
|
||||
auto loop =
|
||||
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
|
||||
loops.push_back(loop);
|
||||
builder = loop.getBodyBuilder();
|
||||
}
|
||||
return loops;
|
||||
}
|
||||
|
||||
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
|
||||
mlir::OpBuilder builder) {
|
||||
CHECK(IsSimpleLoop(loop));
|
||||
|
||||
loop.setUpperBoundMap(mlir::AffineMap::get(
|
||||
loop.getUpperBoundMap().getNumDims(),
|
||||
loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
|
||||
}
|
||||
|
||||
// Tile a loop with trip count N by `size`. For now, N has to be a multiple of
|
||||
// size, but later this constraint will be removed.
|
||||
//
|
||||
// The major loop (with trip count N / size) stays as-is, while the minor loop
|
||||
// (with trip count `size`) will take over the body of `target`, and be placed
|
||||
// as the new body of `target`.
|
||||
//
|
||||
// `target` has to be within the same "perfectly nested loop group" as `loop`.
|
||||
// See the documentation for mlir::getPerfectlyNestedLoops.
|
||||
//
|
||||
// Example:
|
||||
// Before tiling `loop` with tile size X:
|
||||
// for (loop in N)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// // pass loop into affine maps
|
||||
// After:
|
||||
// for (loop in N / X)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// for (tiled_loop in X)
|
||||
// // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
|
||||
//
|
||||
// Design note:
|
||||
// TileLoop is different from mlir::tile. At the moment, the exact tiling
// semantics of mlir::tile are not well documented, but the observed behavior is:
|
||||
// for (i from 0 to N)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// // pass i into affine maps
|
||||
// =>
|
||||
// for (i from 0 to N, step = X)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// for (j from i to min(i + X, N), step = 1)
|
||||
// // pass j into affine maps
|
||||
//
|
||||
// There are two differences between mlir::tile and TileLoop:
|
||||
// * TileLoop always puts the tiling "stepping" logic into AffineExprs.
// With that, all index calculation is done in AffineExprs and is easier to
|
||||
// analyze in a single place.
|
||||
// * TileLoop doesn't plan to use max() and min() to resolve the issue when
|
||||
// N % X != 0. max() and min() are not representable in AffineExprs.
|
||||
// TODO(timshen): support the case where N % X != 0.
|
||||
//
|
||||
// TODO(timshen): consider the possibility to reuse mlir::tile's logic to
|
||||
// achieve the same goal.
|
||||
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
|
||||
mlir::AffineForOp target) {
|
||||
CHECK(IsSimpleLoop(loop));
|
||||
CHECK(IsSimpleLoop(target));
|
||||
{
|
||||
llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
|
||||
getPerfectlyNestedLoops(all_loops, loop);
|
||||
CHECK(absl::c_linear_search(all_loops, target));
|
||||
}
|
||||
|
||||
auto builder = target.getBodyBuilder();
|
||||
|
||||
auto inner_loop =
|
||||
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
|
||||
{
|
||||
auto& inner_operations = inner_loop.getBody()->getOperations();
|
||||
auto& target_operations = target.getBody()->getOperations();
|
||||
|
||||
inner_operations.splice(inner_operations.begin(), target_operations,
|
||||
target_operations.begin(),
|
||||
std::prev(target_operations.end(), 2));
|
||||
|
||||
mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
|
||||
CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
|
||||
SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
|
||||
}
|
||||
|
||||
for (auto& use :
|
||||
llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
|
||||
mlir::Operation* owner = use.getOwner();
|
||||
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
||||
unsigned new_dim = affine_map.operands.size();
|
||||
affine_map.operands.push_back(inner_loop.getInductionVar());
|
||||
std::vector<mlir::AffineExpr> replacements;
|
||||
for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
|
||||
if (affine_map.operands[i] == loop.getInductionVar()) {
|
||||
replacements.push_back(builder.getAffineDimExpr(i) * size +
|
||||
builder.getAffineDimExpr(new_dim));
|
||||
} else {
|
||||
replacements.push_back(builder.getAffineDimExpr(i));
|
||||
}
|
||||
}
|
||||
affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
|
||||
replacements, {}, affine_map.operands.size(), 0);
|
||||
auto new_op =
|
||||
CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
|
||||
owner->replaceAllUsesWith(new_op);
|
||||
owner->erase();
|
||||
}
|
||||
return inner_loop;
|
||||
}
|
||||
|
||||
// Hoist operations out of `where`. [begin_op, end_op) must be the first
|
||||
// operations of their parent loop, and `where` must be an ancestor of that
|
||||
// parent loop.
|
||||
@ -387,21 +225,6 @@ mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
|
||||
return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
|
||||
}
|
||||
|
||||
// Sinks a segment of perfectly nested loops to the bottom. It implements this
|
||||
// by rotating the loop nest by rotate_amount.
|
||||
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
|
||||
int rotate_amount) {
|
||||
CHECK_GE(rotate_amount, 0);
|
||||
std::vector<unsigned> permutation(loops.size());
|
||||
std::iota(permutation.begin(), permutation.end(), unsigned(0));
|
||||
std::rotate(permutation.begin(),
|
||||
permutation.begin() + loops.size() - rotate_amount,
|
||||
permutation.end());
|
||||
mlir::interchangeLoops(
|
||||
llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
|
||||
permutation);
|
||||
}
|
||||
|
||||
struct InitialMlirConvAnchors {
|
||||
std::vector<mlir::AffineForOp> cartesian_product_loops;
|
||||
std::vector<mlir::AffineForOp> reduction_loops;
|
||||
|
@ -0,0 +1,152 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h"
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
|
||||
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
|
||||
namespace xla {
|
||||
namespace mlir_gpu {
|
||||
|
||||
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
|
||||
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
|
||||
return {load.getAffineMap(),
|
||||
std::vector<mlir::Value>(load.getMapOperands().begin(),
|
||||
load.getMapOperands().end())};
|
||||
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
|
||||
return {store.getAffineMap(),
|
||||
std::vector<mlir::Value>(store.getMapOperands().begin(),
|
||||
store.getMapOperands().end())};
|
||||
} else {
|
||||
CHECK(false);
|
||||
}
|
||||
}
|
||||
|
||||
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
|
||||
BoundAffineMap new_affine,
|
||||
mlir::OpBuilder builder) {
|
||||
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
|
||||
return builder.create<mlir::AffineLoadOp>(
|
||||
builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
|
||||
new_affine.operands);
|
||||
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
|
||||
return builder.create<mlir::AffineStoreOp>(
|
||||
builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
|
||||
new_affine.affine_map, new_affine.operands);
|
||||
} else {
|
||||
CHECK(false);
|
||||
}
|
||||
}
|
||||
|
||||
bool IsSimpleLoop(mlir::AffineForOp loop) {
|
||||
return loop.getLowerBoundMap().isSingleConstant() &&
|
||||
loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
|
||||
loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
|
||||
std::next(loop.region().begin()) == loop.region().end();
|
||||
}
|
||||
|
||||
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
std::vector<mlir::AffineForOp> loops;
loops.reserve(upper_bounds.size());
for (int64_t dim : upper_bounds) {
auto loop =
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
loops.push_back(loop);
builder = loop.getBodyBuilder();
}
return loops;
}
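// Editorial usage sketch (not part of the original change):
// CreateNestedSimpleLoops({8, 16}, builder) emits
//   affine.for %i = 0 to 8 { affine.for %j = 0 to 16 { ... } }
// and returns {outer, inner} in that order, each successive loop being created
// with the body builder of the previous one.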
|
||||
|
||||
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
|
||||
mlir::OpBuilder builder) {
|
||||
CHECK(IsSimpleLoop(loop));
|
||||
|
||||
loop.setUpperBoundMap(mlir::AffineMap::get(
|
||||
loop.getUpperBoundMap().getNumDims(),
|
||||
loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
|
||||
}
|
||||
|
||||
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
|
||||
mlir::AffineForOp target) {
|
||||
CHECK(IsSimpleLoop(loop));
|
||||
CHECK(IsSimpleLoop(target));
|
||||
{
|
||||
llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
|
||||
getPerfectlyNestedLoops(all_loops, loop);
|
||||
CHECK(absl::c_linear_search(all_loops, target));
|
||||
}
|
||||
|
||||
auto builder = target.getBodyBuilder();
|
||||
|
||||
auto inner_loop =
|
||||
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
|
||||
{
|
||||
auto& inner_operations = inner_loop.getBody()->getOperations();
|
||||
auto& target_operations = target.getBody()->getOperations();
|
||||
|
||||
inner_operations.splice(inner_operations.begin(), target_operations,
|
||||
target_operations.begin(),
|
||||
std::prev(target_operations.end(), 2));
|
||||
|
||||
mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
|
||||
CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
|
||||
SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
|
||||
}
|
||||
|
||||
for (auto& use :
|
||||
llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
|
||||
mlir::Operation* owner = use.getOwner();
|
||||
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
|
||||
unsigned new_dim = affine_map.operands.size();
|
||||
affine_map.operands.push_back(inner_loop.getInductionVar());
|
||||
std::vector<mlir::AffineExpr> replacements;
|
||||
for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
|
||||
if (affine_map.operands[i] == loop.getInductionVar()) {
|
||||
replacements.push_back(builder.getAffineDimExpr(i) * size +
|
||||
builder.getAffineDimExpr(new_dim));
|
||||
} else {
|
||||
replacements.push_back(builder.getAffineDimExpr(i));
|
||||
}
|
||||
}
|
||||
affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
|
||||
replacements, {}, affine_map.operands.size(), 0);
|
||||
auto new_op =
|
||||
CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
|
||||
owner->replaceAllUsesWith(new_op);
|
||||
owner->erase();
|
||||
}
|
||||
return inner_loop;
|
||||
}
|
||||
|
||||
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
|
||||
int rotate_amount) {
|
||||
CHECK_GE(rotate_amount, 0);
|
||||
std::vector<unsigned> permutation(loops.size());
|
||||
std::iota(permutation.begin(), permutation.end(), unsigned(0));
|
||||
std::rotate(permutation.begin(),
|
||||
permutation.begin() + loops.size() - rotate_amount,
|
||||
permutation.end());
|
||||
mlir::interchangeLoops(
|
||||
llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
|
||||
permutation);
|
||||
}
|
||||
|
||||
} // namespace mlir_gpu
|
||||
} // namespace xla
|
@ -0,0 +1,102 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_
|
||||
|
||||
#include "absl/base/integral_types.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Operation.h" // from @llvm-project
|
||||
|
||||
namespace xla {
|
||||
namespace mlir_gpu {
|
||||
|
||||
struct BoundAffineMap {
|
||||
mlir::AffineMap affine_map;
|
||||
std::vector<mlir::Value> operands;
|
||||
};
|
||||
|
||||
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op);
|
||||
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
|
||||
BoundAffineMap new_affine,
|
||||
mlir::OpBuilder builder);
|
||||
|
||||
bool IsSimpleLoop(mlir::AffineForOp loop);
|
||||
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
|
||||
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder);
|
||||
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
|
||||
mlir::OpBuilder builder);
|
||||
|
||||
// Tile a loop with trip count N by `size`. For now, N has to be a multiple of
|
||||
// size, but later this constraint will be removed.
|
||||
//
|
||||
// The major loop (with trip count N / size) stays as-is, while the minor loop
|
||||
// (with trip count `size`) will take over the body of `target`, and be placed
|
||||
// as the new body of `target`.
|
||||
//
|
||||
// `target` has to be within the same "perfectly nested loop group" as `loop`.
|
||||
// See the documentation for mlir::getPerfectlyNestedLoops.
|
||||
//
|
||||
// Example:
|
||||
// Before tiling `loop` with tile size X:
|
||||
// for (loop in N)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// // pass loop into affine maps
|
||||
// After:
|
||||
// for (loop in N / X)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// for (tiled_loop in X)
|
||||
// // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
|
||||
//
|
||||
// Design note:
|
||||
// TileLoop is different from mlir::tile. At the moment, the exact tiling
// semantics of mlir::tile are not well documented, but the observed behavior is:
|
||||
// for (i from 0 to N)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// // pass i into affine maps
|
||||
// =>
|
||||
// for (i from 0 to N, step = X)
|
||||
// for (unrelated_loop in ...)
|
||||
// for (target in ...)
|
||||
// for (j from i to min(i + X, N), step = 1)
|
||||
// // pass j into affine maps
|
||||
//
|
||||
// There are two differences between mlir::tile and TileLoop:
|
||||
// * TileLoop always puts the tiling "stepping" logic into AffineExprs.
// With that, all index calculation is done in AffineExprs and is easier to
|
||||
// analyze in a single place.
|
||||
// * TileLoop doesn't plan to use max() and min() to resolve the issue when
|
||||
// N % X != 0. max() and min() are not representable in AffineExprs.
|
||||
// TODO(timshen): support the case where N % X != 0.
|
||||
//
|
||||
// TODO(timshen): consider the possibility to reuse mlir::tile's logic to
|
||||
// achieve the same goal.
|
||||
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
|
||||
mlir::AffineForOp target);
|
||||
|
||||
// Sinks a segment of perfectly nested loops to the bottom. It implements this
// by rotating the loop nest by rotate_amount.
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
int rotate_amount);
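// Editorial worked example (not part of the original change): for four
// perfectly nested loops and rotate_amount = 1, the permutation handed to
// mlir::interchangeLoops is {3, 0, 1, 2}, i.e. std::rotate of {0, 1, 2, 3}
// around its last element.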
|
||||
|
||||
} // namespace mlir_gpu
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_
|
@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
|
||||
end = {k, std::min(i * block_size, n)};
|
||||
}
|
||||
|
||||
if (!left_side) {
|
||||
std::swap(end[0], end[1]);
|
||||
}
|
||||
if (transpose_a) {
|
||||
if (!left_side ^ transpose_a) {
|
||||
std::swap(start[0], start[1]);
|
||||
std::swap(end[0], end[1]);
|
||||
}
|
||||
@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
|
||||
}
|
||||
|
||||
XlaOp x_update;
|
||||
auto zero = Zero(builder, S32);
|
||||
auto start_index = ConstantR0WithType(builder, S32, j * block_size);
|
||||
std::vector<XlaOp> update_starts = {start_index, zero};
|
||||
if (left_side) {
|
||||
x_update =
|
||||
BatchDot(inv_block, transpose_a, remainder, false, precision);
|
||||
} else {
|
||||
x_update =
|
||||
BatchDot(remainder, false, inv_block, transpose_a, precision);
|
||||
std::swap(update_starts[0], update_starts[1]);
|
||||
}
|
||||
|
||||
if (i == 0) {
|
||||
|
@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
Array2D<float> avals(spec.m, spec.m);
|
||||
avals.FillRandom(1.0);
|
||||
for (int i = 0; i < spec.m; ++i) {
|
||||
avals(i, i) += 10;
|
||||
avals(i, i) += 30;
|
||||
}
|
||||
|
||||
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
|
||||
@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
|
||||
}
|
||||
|
||||
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
|
||||
ErrorSpec(1e-2, 1e-2));
|
||||
ErrorSpec(3e-2, 3e-2));
|
||||
}
|
||||
|
||||
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
|
||||
std::vector<TriangularSolveTestSpec> specs;
|
||||
for (int m : {5, 10}) {
|
||||
for (int n : {5, 10}) {
|
||||
for (int m : {5, 10, 150}) {
|
||||
for (int n : {5, 10, 150}) {
|
||||
for (bool left_side : {false, true}) {
|
||||
for (bool lower : {false, true}) {
|
||||
for (TriangularSolveOptions::Transpose transpose_a :
|
||||
|
@ -2550,6 +2550,7 @@ filegroup(
|
||||
"common_runtime/executor_factory.h",
|
||||
"common_runtime/function_optimization_registry.h",
|
||||
"common_runtime/graph_optimizer.h",
|
||||
"common_runtime/graph_view.h",
|
||||
"common_runtime/input_colocation_exemption_registry.h",
|
||||
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
|
||||
"common_runtime/local_device.h",
|
||||
@ -2613,6 +2614,7 @@ tf_cuda_library(
|
||||
"common_runtime/function_optimization_registry.cc",
|
||||
"common_runtime/graph_optimizer.cc",
|
||||
"common_runtime/graph_runner.cc",
|
||||
"common_runtime/graph_view.cc",
|
||||
"common_runtime/hierarchical_tree_broadcaster.cc",
|
||||
"common_runtime/input_colocation_exemption_registry.cc",
|
||||
"common_runtime/inspecting_placer.cc",
|
||||
|
@ -382,7 +382,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
||||
}
|
||||
void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||
if (ptr != nullptr) {
|
||||
AddTraceMe("MemoryAllocation", num_bytes);
|
||||
AddTraceMe("MemoryAllocation", ptr);
|
||||
return ptr;
|
||||
}
|
||||
|
||||
@ -390,7 +390,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
||||
if (Extend(unused_alignment, rounded_bytes)) {
|
||||
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||
if (ptr != nullptr) {
|
||||
AddTraceMe("MemoryAllocation", num_bytes);
|
||||
AddTraceMe("MemoryAllocation", ptr);
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
@ -403,7 +403,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
||||
if (MergeTimestampedChunks(rounded_bytes)) {
|
||||
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||
if (ptr != nullptr) {
|
||||
AddTraceMe("MemoryAllocation", num_bytes);
|
||||
AddTraceMe("MemoryAllocation", ptr);
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
@ -417,7 +417,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
||||
Extend(unused_alignment, rounded_bytes)) {
|
||||
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
|
||||
if (ptr != nullptr) {
|
||||
AddTraceMe("MemoryAllocation", num_bytes);
|
||||
AddTraceMe("MemoryAllocation", ptr);
|
||||
return ptr;
|
||||
}
|
||||
}
|
||||
@ -439,7 +439,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
|
||||
}
|
||||
|
||||
void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
|
||||
int64 requested_bytes) {
|
||||
const void* chunk_ptr) {
|
||||
// Internal users will see the memory profile with default trace level.
|
||||
auto traceme_level = profiler::TraceMeLevel::kVerbose;
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
@ -451,12 +451,17 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
|
||||
AllocatorStats stats = stats_;
|
||||
int64 bytes_available =
|
||||
memory_limit_ - stats.bytes_reserved - stats.bytes_in_use;
|
||||
BFCAllocator::Chunk* chunk =
|
||||
ChunkFromHandle(region_manager_.get_handle(chunk_ptr));
|
||||
|
||||
return absl::StrCat(traceme_name, "#allocator_name=", name_,
|
||||
",bytes_reserved=", stats.bytes_reserved,
|
||||
",bytes_allocated=", stats.bytes_in_use,
|
||||
",bytes_available=", bytes_available,
|
||||
",peak_bytes_in_use=", stats.peak_bytes_in_use,
|
||||
",requested_bytes=", requested_bytes,
|
||||
",requested_bytes=", chunk->requested_size,
|
||||
",allocation_bytes=", chunk->size,
|
||||
",addr=", reinterpret_cast<uint64>(chunk_ptr),
|
||||
",tf_op=", pending_op_name, ",id=", pending_step_id,
|
||||
"#");
|
||||
},
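// Editorial note (illustrative values only): the fields above produce an
// annotation of roughly this form:
//   MemoryAllocation#allocator_name=GPU_0_bfc,bytes_reserved=0,
//   bytes_allocated=1048576,bytes_available=...,peak_bytes_in_use=...,
//   requested_bytes=4096,allocation_bytes=4096,addr=...,tf_op=MatMul,id=1#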
|
||||
@ -595,8 +600,10 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) {
|
||||
BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
|
||||
CHECK(h != kInvalidChunkHandle);
|
||||
|
||||
int64 requested_bytes = ChunkFromHandle(h)->requested_size;
|
||||
MarkFree(h);
|
||||
// TraceMe needs to be added after MarkFree and before InsertFreeChunkIntoBin
|
||||
// for correct memory stats.
|
||||
AddTraceMe("MemoryDeallocation", ptr);
|
||||
|
||||
// Consider coalescing it.
|
||||
if (timing_counter_) {
|
||||
@ -609,8 +616,6 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) {
|
||||
if (VLOG_IS_ON(4)) {
|
||||
LOG(INFO) << "F: " << RenderOccupancy();
|
||||
}
|
||||
|
||||
AddTraceMe("MemoryDeallocation", -requested_bytes);
|
||||
}
|
||||
|
||||
// Merges h1 and h2 when Chunk(h1)->next is h2 and Chunk(h2)->prev is c1.
|
||||
|
@ -116,8 +116,9 @@ class BFCAllocator : public Allocator {
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||
|
||||
// Add TraceMe (in memory allocation and deallocation) for memory stats
|
||||
// profiling. The requested_bytes can be negative if it's a deallocation.
|
||||
void AddTraceMe(absl::string_view traceme_name, int64 requested_bytes)
|
||||
// profiling. The chunk_ptr is passed to get information such as address,
|
||||
// chunk size and requested_size.
|
||||
void AddTraceMe(absl::string_view traceme_name, const void* chunk_ptr)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(lock_);
|
||||
|
||||
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
|
||||
|
@ -1337,10 +1337,11 @@ Status DirectSession::CreateExecutors(
|
||||
device_mgr_.get(), options_.env, &options_.config, graph_def_version,
|
||||
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
|
||||
/*parent=*/nullptr, custom_kernel_creator, session_metadata,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}));
|
||||
Rendezvous::Factory{
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}}));
|
||||
|
||||
GraphOptimizer optimizer(optimizer_opts);
|
||||
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
|
||||
|
@ -128,11 +128,11 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
|
||||
thread::ThreadPool* thread_pool,
|
||||
DistributedFunctionLibraryRuntime* cluster_flr,
|
||||
const CustomKernelCreator* custom_kernel_creator) {
|
||||
Rendezvous::Factory rendezvous_factory =
|
||||
Rendezvous::Factory rendezvous_factory{
|
||||
[this](const int64 step_id, const DeviceMgr*, Rendezvous** r) {
|
||||
*r = CreateRendezvous(step_id);
|
||||
return Status::OK();
|
||||
};
|
||||
}};
|
||||
if (lazy_copy_function_remote_inputs_) {
|
||||
pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
|
||||
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
|
||||
@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster(
|
||||
if (rendezvous_ != nullptr) rendezvous_->Unref();
|
||||
rendezvous_ = r;
|
||||
remote_eager_workers_ = std::move(remote_eager_workers);
|
||||
pflr_->InitializeDeviceSet();
|
||||
InitPrioritizedDeviceTypeList();
|
||||
|
||||
default_executor_.ClearError();
|
||||
|
@ -583,11 +583,11 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
if (executor.Async()) {
|
||||
const DataTypeVector& output_dtypes = kernel->output_dtypes();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
|
||||
retvals[i] = TensorHandle::CreateEmptyLocalHandle(
|
||||
/* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
|
||||
/* op_device= */ kernel->device(),
|
||||
/* resource_device= */ kernel->OutputResourceDevice(i),
|
||||
output_dtypes[i], &ctx, &retvals[i]));
|
||||
output_dtypes[i], &ctx);
|
||||
}
|
||||
auto node = absl::make_unique<AsyncExecuteNode>(
|
||||
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
|
||||
@ -773,18 +773,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
|
||||
// remote device here. We just need to know that it is remote. If we need
|
||||
// to copy this tensor to this process, the remote end will know the
|
||||
// correct device of this handle.
|
||||
Status status = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
id, i, remote_task, output_dtypes[i], op_device, &ctx, &retvals[i]);
|
||||
if (!status.ok()) {
|
||||
for (int j = 0; j < i; ++j) {
|
||||
retvals[j]->PoisonRemote(
|
||||
errors::Internal(
|
||||
"Failed to construct unshaped remote tensor handle at index ",
|
||||
i, " for op ", op->Name()),
|
||||
op_device, ctx.GetContextViewId());
|
||||
}
|
||||
return status;
|
||||
}
|
||||
retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
id, i, remote_task, output_dtypes[i], op_device, &ctx);
|
||||
}
|
||||
|
||||
if (ctx.LazyCopyFunctionRemoteInputs()) {
|
||||
@ -1056,12 +1046,11 @@ Status EagerKernelExecute(
|
||||
|
||||
for (int i = 0; i < retvals.size(); ++i) {
|
||||
if (retvals[i] == nullptr) {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
|
||||
retvals[i] = TensorHandle::CreateLocalHandle(
|
||||
std::move(outputs[i]),
|
||||
/* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
|
||||
/* op_device= */ kernel->device(),
|
||||
/* resource_device= */ kernel->OutputResourceDevice(i), ctx,
|
||||
&retvals[i]));
|
||||
/* resource_device= */ kernel->OutputResourceDevice(i), ctx);
|
||||
} else {
|
||||
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
|
||||
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
|
||||
@ -1100,8 +1089,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
h->Ref();
|
||||
*result = h;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
|
||||
d, dstd, h->resource_device(), h->dtype, ctx, result));
|
||||
*result = TensorHandle::CreateEmptyLocalHandle(
|
||||
d, dstd, h->resource_device(), h->dtype, ctx);
|
||||
}
|
||||
|
||||
Status s;
|
||||
@ -1169,9 +1158,9 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
h->Ref();
|
||||
*result = h;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
|
||||
*result = TensorHandle::CreateEmptyLocalHandle(
|
||||
/* d= */ d, /* op_device= */ device,
|
||||
/*resource_device=*/nullptr, h->dtype, ctx, result));
|
||||
/*resource_device=*/nullptr, h->dtype, ctx);
|
||||
}
|
||||
} else {
|
||||
if (mirror) {
|
||||
@ -1194,8 +1183,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
|
||||
h->Ref();
|
||||
*result = h;
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
|
||||
recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
|
||||
*result = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
recv_op_id, 0, remote_task, h->dtype, device, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,8 +50,8 @@ void EagerProcessFunctionLibraryRuntime::Run(
|
||||
std::move(done));
|
||||
}
|
||||
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
|
||||
done =
|
||||
ApplyCleanUpToDoneCallback(cleanup_items, done, /*rendezvous=*/nullptr);
|
||||
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
|
||||
/*rendezvous=*/nullptr);
|
||||
|
||||
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
|
||||
InternalArgs* comp_args) -> Status {
|
||||
|
@ -106,40 +106,36 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
|
||||
return GetResourceHandleInfoImpl(get_resource_info);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
|
||||
TensorHandle** h) {
|
||||
TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
|
||||
// TODO(b/136608821): Move away from nullptr
|
||||
tensorflow::Tensor tensor = t;
|
||||
return CreateLocalHandle(std::move(tensor),
|
||||
/*d=*/nullptr,
|
||||
/*op_device=*/nullptr,
|
||||
/*ctx=*/nullptr, h);
|
||||
/*ctx=*/nullptr);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device, EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
|
||||
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device,
|
||||
EagerContext* ctx) {
|
||||
return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device,
|
||||
Device* resource_device,
|
||||
EagerContext* ctx, TensorHandle** h) {
|
||||
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device,
|
||||
Device* resource_device,
|
||||
EagerContext* ctx) {
|
||||
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
|
||||
*h = new TensorHandle(std::move(t), d, op_device, ctx);
|
||||
return new TensorHandle(std::move(t), d, op_device, ctx);
|
||||
} else {
|
||||
*h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
|
||||
return new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
|
||||
EagerContext* ctx, TensorHandle** h) {
|
||||
*h = new TensorHandle(std::move(t), d, ctx);
|
||||
|
||||
return Status::OK();
|
||||
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t,
|
||||
CustomDevice* d,
|
||||
EagerContext* ctx) {
|
||||
return new TensorHandle(std::move(t), d, ctx);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
|
||||
@ -190,13 +186,11 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
|
||||
<< " tensor: " << t.DeviceSafeDebugString();
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device,
|
||||
DataType dtype, EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
*h = new TensorHandle(d, op_device, resource_device, dtype, ctx);
|
||||
|
||||
return Status::OK();
|
||||
TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device,
|
||||
DataType dtype,
|
||||
EagerContext* ctx) {
|
||||
return new TensorHandle(d, op_device, resource_device, dtype, ctx);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(Device* d, Device* op_device,
|
||||
@ -214,14 +208,10 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
|
||||
}
|
||||
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
const string& remote_task,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
*h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
|
||||
|
||||
return Status::OK();
|
||||
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
|
||||
int64 op_id, int32 output_num, const string& remote_task, DataType dtype,
|
||||
Device* d, EagerContext* ctx) {
|
||||
return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(int64 op_id, int32 output_num,
|
||||
@ -239,13 +229,11 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
|
||||
<< " device: " << VariantDeviceDebugString(device_);
|
||||
}
|
||||
|
||||
Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx,
|
||||
TensorHandle** h) {
|
||||
*h = new TensorHandle(op_id, output_num, dtype, d, ctx);
|
||||
|
||||
return Status::OK();
|
||||
TensorHandle* TensorHandle::CreateLazyRemoteHandle(int64 op_id,
|
||||
int32 output_num,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx) {
|
||||
return new TensorHandle(op_id, output_num, dtype, d, ctx);
|
||||
}
|
||||
|
||||
TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,
|
||||
|
@ -77,27 +77,27 @@ class TensorHandle : public core::RefCounted {
|
||||
|
||||
public:
|
||||
// TensorHandle with no assigned device
|
||||
static Status CreateLocalHandle(const tensorflow::Tensor& t,
|
||||
TensorHandle** h);
|
||||
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device, EagerContext* ctx,
|
||||
TensorHandle** h);
|
||||
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device, Device* resource_device,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device, DataType dtype,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
|
||||
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device, EagerContext* ctx);
|
||||
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
|
||||
Device* op_device,
|
||||
Device* resource_device,
|
||||
EagerContext* ctx);
|
||||
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t,
|
||||
CustomDevice* d, EagerContext* ctx);
|
||||
static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
|
||||
Device* resource_device,
|
||||
DataType dtype,
|
||||
EagerContext* ctx);
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
const string& remote_task,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx, TensorHandle** h);
|
||||
static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
|
||||
const string& remote_task,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx);
|
||||
static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
|
||||
DataType dtype, Device* d,
|
||||
EagerContext* ctx);
|
||||
#endif // IS_MOBILE_PLATFORM
|
||||
|
||||
~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }
|
||||
|
@ -38,15 +38,10 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
|
||||
&device_mgr, false, nullptr, nullptr, nullptr);
|
||||
TensorHandle* sync_th;
|
||||
EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
|
||||
ctx, &sync_th)
|
||||
.ok());
|
||||
TensorHandle* async_th;
|
||||
EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
|
||||
DataType::DT_UINT16, ctx,
|
||||
&async_th)
|
||||
.ok());
|
||||
TensorHandle* sync_th =
|
||||
TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, ctx);
|
||||
TensorHandle* async_th = TensorHandle::CreateEmptyLocalHandle(
|
||||
nullptr, nullptr, nullptr, DataType::DT_UINT16, ctx);
|
||||
|
||||
EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());
|
||||
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/common_runtime/costmodel_manager.h"
|
||||
#include "tensorflow/core/common_runtime/executor_factory.h"
|
||||
#include "tensorflow/core/common_runtime/graph_view.h"
|
||||
#include "tensorflow/core/common_runtime/metrics.h"
|
||||
#include "tensorflow/core/common_runtime/pending_counts.h"
|
||||
#include "tensorflow/core/common_runtime/renamed_device.h"
|
||||
@ -124,19 +125,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
|
||||
} // namespace nodestats
|
||||
|
||||
class ExecutorImpl;
|
||||
class GraphView;
|
||||
|
||||
struct EdgeInfo {
|
||||
int dst_id;
|
||||
int output_slot : 31;
|
||||
// true if this is the last info for output_slot in the EdgeInfo list.
|
||||
bool is_last : 1;
|
||||
int input_slot;
|
||||
};
|
||||
|
||||
struct ControlEdgeInfo {
|
||||
int dst_id;
|
||||
};
|
||||
|
||||
// Time the execution of kernels (in CPU cycles). Used to dynamically identify
|
||||
// inexpensive kernels which can be dispatched inline.
|
||||
@ -148,196 +136,9 @@ struct KernelTimer {
|
||||
}
|
||||
};
|
||||
|
||||
// Compact structure representing a graph node and its associated kernel.
|
||||
//
|
||||
// Each NodeItem is an element of exactly one GraphView.
|
||||
struct NodeItem {
|
||||
NodeItem() {}
|
||||
|
||||
// The index of this node's item in its GraphView.
|
||||
int node_id = -1;
|
||||
|
||||
// Cached attributes of this node for fast lookup.
|
||||
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
|
||||
bool is_merge : 1; // True iff IsMerge(node)
|
||||
bool is_enter : 1; // True iff IsEnter(node)
|
||||
bool is_constant_enter : 1; // True iff IsEnter(node) and
|
||||
// node->GetAttr("is_constant") == true.
|
||||
bool is_exit : 1; // True iff IsExit(node)
|
||||
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
|
||||
bool is_source : 1; // True iff IsSource(node)
|
||||
// True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
|
||||
bool is_enter_exit_or_next_iter : 1;
|
||||
bool is_transfer_node : 1; // True iff IsTransferNode(node)
|
||||
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
|
||||
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
|
||||
bool is_next_iteration : 1; // True iff IsNextIteration(node)
|
||||
bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp")
|
||||
bool
|
||||
is_any_consumer_merge_or_control_trigger : 1; // True iff the destination
|
||||
// of any output edge is a
|
||||
// merge or control trigger
|
||||
// node.
|
||||
|
||||
// The kernel for this node.
|
||||
OpKernel* kernel = nullptr;
|
||||
|
||||
// If the kernel is a Const op, this contains a pointer to the constant tensor.
|
||||
const Tensor* const_tensor = nullptr;
|
||||
|
||||
// Cached values of node->num_inputs() and node->num_outputs(), to
|
||||
// avoid levels of indirection.
|
||||
int num_inputs;
|
||||
int num_outputs;
|
||||
|
||||
// ExecutorImpl::tensors_[input_start] is the 1st positional input
|
||||
// for this node.
|
||||
int input_start = 0;
|
||||
|
||||
// Number of output edges, excluding control edges.
|
||||
int32 num_output_edges;
|
||||
|
||||
// Number of output control edges.
|
||||
int32 num_output_control_edges;
|
||||
|
||||
// If non-null, contains an array of num_outputs bools, where the ith bool
|
||||
// is true if and only if the ith output is consumed by another node.
|
||||
std::unique_ptr<bool[]> outputs_required;
|
||||
|
||||
gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
|
||||
return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
|
||||
num_output_edges);
|
||||
}
|
||||
|
||||
gtl::ArraySlice<EdgeInfo> output_edges() const {
|
||||
return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
|
||||
}
|
||||
|
||||
gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
|
||||
return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
|
||||
num_output_control_edges);
|
||||
}
|
||||
|
||||
DataType input_type(int i) const {
|
||||
DCHECK_LT(i, num_inputs);
|
||||
return static_cast<DataType>(input_type_base()[i]);
|
||||
}
|
||||
DataType output_type(int i) const {
|
||||
DCHECK_LT(i, num_outputs);
|
||||
return static_cast<DataType>(output_type_base()[i]);
|
||||
}
|
||||
|
||||
// Return array of per-output allocator attributes.
|
||||
const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
|
||||
|
||||
// Return array of expected input index from which each output should
|
||||
// be forwarded:
|
||||
// kNeverForward (-2) for DO NOT FORWARD (must allocate).
|
||||
// kNoReservation (-1) for no expected forwarding.
|
||||
// 0... for forward from that input.
|
||||
const int* forward_from() const { return forward_from_base(); }
|
||||
|
||||
string DebugString() const {
|
||||
string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
|
||||
if (is_source) {
|
||||
strings::StrAppend(&ret, " source}");
|
||||
} else {
|
||||
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class GraphView;
|
||||
|
||||
// Variable length section starts immediately after *this
|
||||
// (uint8 is enough for DataType).
|
||||
// EdgeInfo out_edges[num_out_edges];
|
||||
// AllocatorAttributes output_attr[num_outputs];
|
||||
// int forward_from[num_outputs];
|
||||
// uint8 input_type[num_inputs];
|
||||
// uint8 output_type[num_outputs];
|
||||
|
||||
// Return pointer to variable length section.
|
||||
char* var() const {
|
||||
return const_cast<char*>(reinterpret_cast<const char*>(this) +
|
||||
sizeof(NodeItem));
|
||||
}
|
||||
|
||||
EdgeInfo* output_edge_base() const {
|
||||
return reinterpret_cast<EdgeInfo*>(var());
|
||||
}
|
||||
|
||||
ControlEdgeInfo* output_control_edge_base() const {
|
||||
return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
|
||||
num_output_edges);
|
||||
}
|
||||
|
||||
AllocatorAttributes* output_attr_base() const {
|
||||
return reinterpret_cast<AllocatorAttributes*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges);
|
||||
}
|
||||
int* forward_from_base() const {
|
||||
return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) *
|
||||
num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs);
|
||||
}
|
||||
uint8* input_type_base() const {
|
||||
return reinterpret_cast<uint8*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
|
||||
}
|
||||
uint8* output_type_base() const {
|
||||
return reinterpret_cast<uint8*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
|
||||
sizeof(uint8) * num_inputs);
|
||||
}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
|
||||
};
|
||||
|
||||
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
|
||||
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
|
||||
|
||||
// Immutable view of a Graph organized for efficient execution.
|
||||
class GraphView {
|
||||
public:
|
||||
GraphView() : space_(nullptr) {}
|
||||
~GraphView();
|
||||
|
||||
Status Initialize(const Graph* g);
|
||||
Status SetAllocAttrs(const Graph* g, const Device* device);
|
||||
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
|
||||
|
||||
NodeItem* node(int32 id) const {
|
||||
DCHECK_GE(id, 0);
|
||||
DCHECK_LT(id, num_nodes_);
|
||||
uint32 offset = node_offsets_[id];
|
||||
return ((offset == kuint32max)
|
||||
? nullptr
|
||||
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
|
||||
}
|
||||
|
||||
int32 num_nodes() const { return num_nodes_; }
|
||||
|
||||
private:
|
||||
char* InitializeNode(char* ptr, const Node* n);
|
||||
size_t NodeItemBytes(const Node* n);
|
||||
|
||||
int32 num_nodes_ = 0;
|
||||
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
|
||||
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
|
||||
|
||||
char* space_; // NodeItem objects are allocated here
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
|
||||
};
|
||||
|
||||
class ExecutorImpl : public Executor {
|
||||
public:
|
||||
explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() {
|
||||
@ -499,237 +300,6 @@ class ExecutorImpl : public Executor {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
|
||||
};
|
||||
|
||||
// Infer memory allocation attributes of a node n's output,
|
||||
// based on its use node dst. Note that dst might not be directly
|
||||
// connected to n by a single edge, but might be a downstream
|
||||
// consumer of n's output by reference. *attr is updated with any
|
||||
// necessary attributes.
|
||||
Status InferAllocAttr(const Node* n, const Node* dst,
|
||||
const DeviceNameUtils::ParsedName& local_dev_name,
|
||||
AllocatorAttributes* attr);
|
||||
|
||||
GraphView::~GraphView() {
|
||||
static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
|
||||
"Update code if AllocatorAttributes gains a destructor");
|
||||
static_assert(std::is_trivially_destructible<EdgeInfo>::value,
|
||||
"Update code if EdgeInfo gains a destructor");
|
||||
for (int i = 0; i < num_nodes_; i++) {
|
||||
NodeItem* n = node(i);
|
||||
if (n != nullptr) {
|
||||
n->NodeItem::~NodeItem();
|
||||
// Memory for "n" itself is held in space_ & gets cleaned up below
|
||||
}
|
||||
}
|
||||
delete[] node_offsets_;
|
||||
delete[] space_;
|
||||
}
|
||||
|
||||
typedef std::tuple<int32, int32> OutputAndControlEdges;
|
||||
|
||||
static OutputAndControlEdges CountOutputEdges(const Node* n) {
|
||||
DCHECK_LE(n->out_edges().size(), kint32max);
|
||||
int32 num_output_edges = 0;
|
||||
int32 num_output_control_edges = 0;
|
||||
for (auto e : n->out_edges()) {
|
||||
if (IsSink(e->dst())) continue;
|
||||
if (e->IsControlEdge()) {
|
||||
++num_output_control_edges;
|
||||
} else {
|
||||
++num_output_edges;
|
||||
}
|
||||
}
|
||||
return OutputAndControlEdges(num_output_edges, num_output_control_edges);
|
||||
}
|
||||
|
||||
size_t GraphView::NodeItemBytes(const Node* n) {
|
||||
int32 num_output_edges;
|
||||
int32 num_output_control_edges;
|
||||
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
// Compute number of bytes needed for NodeItem and variable length data.
|
||||
// We do not subtract sizeof(var) since num_inputs/num_outputs might
|
||||
// both be zero.
|
||||
const size_t raw_bytes =
|
||||
sizeof(NodeItem) // Fixed
|
||||
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
|
||||
+ num_output_control_edges * //
|
||||
sizeof(ControlEdgeInfo) // output_control_edges[...]
|
||||
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
|
||||
+ num_outputs * sizeof(int) // forward_from[num_outputs]
|
||||
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
|
||||
+ num_outputs * sizeof(uint8); // output_type[num_outputs]
|
||||
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
|
||||
static_assert(kItemAlignment % alignof(NodeItem) == 0,
|
||||
"NodeItem must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
|
||||
"EdgeInfo must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0,
|
||||
"ControlEdgeInfo must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
|
||||
"AllocatorAttributes must be aligned with kItemAlignment");
|
||||
static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
|
||||
"NodeItem must be aligned with EdgeInfo");
|
||||
static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
|
||||
"NodeItem must be aligned with AllocatorAttributes");
|
||||
static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
|
||||
"EdgeInfo must be aligned with AllocatorAttributes");
|
||||
const size_t bytes =
|
||||
((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
|
||||
return bytes;
|
||||
}
|
||||
|
||||
char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
||||
const int id = n->id();
|
||||
CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor
|
||||
|
||||
const size_t bytes = NodeItemBytes(n);
|
||||
constexpr size_t kItemAlignment = sizeof(NodeItem*);
|
||||
CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
|
||||
NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
|
||||
|
||||
// We store a 32-bit offset relative to the beginning of space_, so that we
|
||||
// only need an array of 32-bit values to map from node id to the NodeItem*,
|
||||
// (versus 64 bits on most machines if we just stored an array of NodeItem*
|
||||
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
|
||||
// values as "int" vs "size_t" in CHECK_LE.
|
||||
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
|
||||
const uint32 offset = static_cast<uint32>(ptr - space_);
|
||||
node_offsets_[id] = offset;
|
||||
ptr += bytes;
|
||||
|
||||
int32 num_output_edges;
|
||||
int32 num_output_control_edges;
|
||||
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
new (item) NodeItem();
|
||||
item->num_inputs = num_inputs;
|
||||
item->num_outputs = num_outputs;
|
||||
item->num_output_edges = num_output_edges;
|
||||
item->num_output_control_edges = num_output_control_edges;
|
||||
|
||||
// Fill output edges.
|
||||
// Keep track of the last EdgeInfo in the EdgeInfo array that references
|
||||
// a given output slot. For all but the last, we need to do a copy of the
|
||||
// Tensor when propagating results downstream in the graph, but for the
|
||||
// last one, we can just do a move of the Tensor object to propagate it.
|
||||
gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
|
||||
EdgeInfo* dst_edge = item->output_edge_base();
|
||||
for (auto e : n->out_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
dst_edge->dst_id = e->dst()->id();
|
||||
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
|
||||
dst_edge->output_slot = e->src_output();
|
||||
dst_edge->is_last = false;
|
||||
const int output_slot = dst_edge->output_slot;
|
||||
if (output_slot >= 0) {
|
||||
last_indices[output_slot] = dst_edge;
|
||||
}
|
||||
// NOTE: The `input_slot` will be rewritten to the frame-wide offset later
|
||||
// in `ExecutorImpl::Initialize()`.
|
||||
dst_edge->input_slot = e->dst_input();
|
||||
dst_edge++;
|
||||
}
|
||||
for (EdgeInfo* edge_info : last_indices) {
|
||||
if (edge_info != nullptr) {
|
||||
edge_info->is_last = true;
|
||||
}
|
||||
}
|
||||
ControlEdgeInfo* dst_control_edge = item->output_control_edge_base();
|
||||
for (auto e : n->out_edges()) {
|
||||
if (!e->IsControlEdge() || IsSink(e->dst())) continue;
|
||||
dst_control_edge->dst_id = e->dst()->id();
|
||||
dst_control_edge++;
|
||||
}
|
||||
|
||||
AllocatorAttributes* output_attrs = item->output_attr_base();
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
new (&output_attrs[i]) AllocatorAttributes();
|
||||
}
|
||||
|
||||
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
|
||||
uint8* input_types = item->input_type_base();
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
input_types[i] = static_cast<uint8>(n->input_type(i));
|
||||
DCHECK_EQ(item->input_type(i), n->input_type(i));
|
||||
}
|
||||
|
||||
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
|
||||
{
|
||||
std::vector<int> forward_input;
|
||||
Status fwd_status =
|
||||
GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
|
||||
std::vector<int> scoped_allocator_attrs;
|
||||
Status sa_status =
|
||||
GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
|
||||
|
||||
int* forward_from = item->forward_from_base();
|
||||
uint8* output_types = item->output_type_base();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_types[i] = static_cast<uint8>(n->output_type(i));
|
||||
DCHECK_EQ(item->output_type(i), n->output_type(i));
|
||||
|
||||
forward_from[i] = OpKernelContext::Params::kNoReservation;
|
||||
if (sa_status.ok()) {
|
||||
for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
|
||||
if (scoped_allocator_attrs[j] == i) {
|
||||
// This output slot must be explicitly allocated from a
|
||||
// ScopedAllocator.
|
||||
forward_from[i] = OpKernelContext::Params::kNeverForward;
|
||||
DCHECK_EQ(output_attrs[i].scope_id, 0);
|
||||
output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fwd_status.ok() &&
|
||||
forward_from[i] == OpKernelContext::Params::kNoReservation) {
|
||||
DCHECK_EQ(forward_input.size() % 2, 0);
|
||||
for (int j = 0; j < forward_input.size(); j += 2) {
|
||||
if (forward_input[j + 1] == i) {
|
||||
DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
|
||||
forward_from[i] = forward_input[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
|
||||
|
||||
Status GraphView::Initialize(const Graph* g) {
|
||||
CHECK(node_offsets_ == nullptr);
|
||||
const int num_nodes = g->num_node_ids();
|
||||
num_nodes_ = num_nodes;
|
||||
size_t total_bytes = 0;
|
||||
for (const Node* n : g->nodes()) {
|
||||
if (n->out_edges().size() > kint32max) {
|
||||
return errors::InvalidArgument(
|
||||
"The executor cannot handle nodes with more than ", kint32max,
|
||||
" output edges. Node ", n->name(), " had ", n->out_edges().size(),
|
||||
" output edges.");
|
||||
}
|
||||
total_bytes += NodeItemBytes(n);
|
||||
}
|
||||
|
||||
node_offsets_ = new uint32[num_nodes];
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
node_offsets_[i] = kuint32max;
|
||||
}
|
||||
|
||||
space_ = new char[total_bytes]; // NodeItem objects are allocated here
|
||||
char* ptr = space_;
|
||||
for (const Node* n : g->nodes()) {
|
||||
ptr = InitializeNode(ptr, n);
|
||||
}
|
||||
CHECK_EQ(ptr, space_ + total_bytes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
|
||||
size_t* max_dead_count) {
|
||||
const size_t num_in_edges = n->in_edges().size();
|
||||
@ -908,154 +478,6 @@ bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
|
||||
return false;
|
||||
}
|
||||
|
||||
void GraphView::SetScopedAllocatorAttrs(
|
||||
const std::vector<const Node*>& sa_nodes) {
|
||||
for (const Node* sa : sa_nodes) {
|
||||
NodeItem* sa_item = node(sa->id());
|
||||
AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
|
||||
// Control edges out of the ScopedAllocator should be use instances, but may
|
||||
// include a few other nodes.
|
||||
for (const auto& e : sa->out_edges()) {
|
||||
if (IsSink(e->dst()) || !e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
Node* use_node = e->dst();
|
||||
NodeItem* item = node(use_node->id());
|
||||
AllocatorAttributes* use_attrs = item->output_attr_base();
|
||||
std::vector<int> scoped_allocator_attrs;
|
||||
Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
|
||||
&scoped_allocator_attrs);
|
||||
if (!s.ok()) {
|
||||
VLOG(2) << "Failed to find expected ScopedAllocator attr on "
|
||||
<< use_node->name();
|
||||
continue;
|
||||
}
|
||||
// There can be more than one output using ScopedAllocation, but this
|
||||
// analysis assumes they use the same ScopedAllocator.
|
||||
for (const auto& e : use_node->out_edges()) {
|
||||
if (IsSink(e->dst()) || !e->IsControlEdge()) {
|
||||
AllocatorAttributes attr;
|
||||
if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
|
||||
e->src_output(), &attr)) {
|
||||
// Set the scope_id on this use instance node.
|
||||
(use_attrs + e->src_output())->Merge(attr);
|
||||
// Propagate the other attributes of this node back to the SA node.
|
||||
attr = *(use_attrs + e->src_output());
|
||||
attr.scope_id = 0;
|
||||
sa_attrs->Merge(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
|
||||
Status s;
|
||||
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
|
||||
|
||||
std::vector<const Node*> scoped_allocator_instances;
|
||||
for (const Node* n : g->nodes()) {
|
||||
NodeItem* item = node(n->id());
|
||||
AllocatorAttributes* attrs = item->output_attr_base();
|
||||
if (IsScopedAllocator(n)) {
|
||||
scoped_allocator_instances.push_back(n);
|
||||
}
|
||||
|
||||
// Examine the out edges of each node looking for special use
|
||||
// cases that may affect memory allocation attributes.
|
||||
for (const auto& e : n->out_edges()) {
|
||||
if (!e->IsControlEdge()) {
|
||||
AllocatorAttributes attr;
|
||||
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
|
||||
if (!s.ok()) return s;
|
||||
if (attr.value != 0 || attr.scope_id != 0) {
|
||||
attrs[e->src_output()].Merge(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int out = 0; out < n->num_outputs(); out++) {
|
||||
const OpKernel* op_kernel = item->kernel;
|
||||
DCHECK_LT(out, op_kernel->output_memory_types().size());
|
||||
bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
|
||||
if (on_host) {
|
||||
AllocatorAttributes h;
|
||||
h.set_on_host(on_host);
|
||||
attrs[out].Merge(h);
|
||||
}
|
||||
}
|
||||
}
|
||||
SetScopedAllocatorAttrs(scoped_allocator_instances);
|
||||
return s;
|
||||
}
|
||||
|
||||
Status InferAllocAttr(const Node* n, const Node* dst,
|
||||
const DeviceNameUtils::ParsedName& local_dev_name,
|
||||
AllocatorAttributes* attr) {
|
||||
Status s;
|
||||
// Note that it's possible for *n to be a Recv and *dst to be a Send,
|
||||
// so these two cases are not mutually exclusive.
|
||||
if (IsRecv(n)) {
|
||||
string src_name;
|
||||
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_src_name;
|
||||
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
|
||||
s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
|
||||
n->name());
|
||||
return s;
|
||||
}
|
||||
if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
|
||||
// Value is going to be the sink of an RPC.
|
||||
attr->set_nic_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
|
||||
} else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
|
||||
parsed_src_name.type != "CPU") {
|
||||
// Value is going to be the sink of a local DMA from GPU to CPU (or
|
||||
// other types of accelerators).
|
||||
attr->set_gpu_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
|
||||
} else {
|
||||
VLOG(2) << "default alloc case local type " << local_dev_name.type
|
||||
<< " remote type " << parsed_src_name.type;
|
||||
}
|
||||
}
|
||||
if (IsSend(dst)) {
|
||||
string dst_name;
|
||||
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_dst_name;
|
||||
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
|
||||
s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
|
||||
n->name());
|
||||
return s;
|
||||
}
|
||||
if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
|
||||
// Value is going to be the source of an RPC.
|
||||
attr->set_nic_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the source of an RPC out";
|
||||
} else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
|
||||
parsed_dst_name.type != "CPU") {
|
||||
// Value is going to be the source of a local DMA from CPU to GPU (or
|
||||
// other types of accelerators).
|
||||
// Note that this does not cover the case where the allocation of the
|
||||
// output tensor is not generated by the src: n.
|
||||
attr->set_gpu_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
|
||||
} else {
|
||||
VLOG(2) << "default alloc case local type " << local_dev_name.type
|
||||
<< " remote type " << parsed_dst_name.type;
|
||||
}
|
||||
}
|
||||
if (n->IsCollective()) {
|
||||
// We'll make the sweeping assumption that any collective op is going
|
||||
// to be involved in network i/o.
|
||||
attr->set_nic_compatible(true);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
// The state associated with one invocation of ExecutorImpl::Run.
|
||||
// ExecutorState dispatches nodes when they become ready and keeps
|
||||
// track of how many predecessors of a node have not yet completed (pending_).
|
||||
|
@ -164,10 +164,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr,
|
||||
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
|
||||
/*session_metadata=*/nullptr,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}));
|
||||
Rendezvous::Factory{
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}}));
|
||||
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
|
||||
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");
|
||||
|
@ -68,10 +68,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool,
|
||||
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
|
||||
/*session_metadata=*/nullptr,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}));
|
||||
Rendezvous::Factory{
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}}));
|
||||
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
}
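The fixture changes above replace a bare rendezvous-creating lambda with one wrapped in Rendezvous::Factory{...}. A minimal standalone sketch of that shape, using a hypothetical Factory/TakesFactory pair rather than the real TensorFlow types:

#include <cstdio>
#include <functional>

// Hypothetical stand-ins for tensorflow::Rendezvous::Factory and the runtime
// constructor parameter; only the shape of the change is illustrated.
struct Factory {
  std::function<int(int)> create;                  // was: a bare std::function
  explicit operator bool() const { return static_cast<bool>(create); }
};

static void TakesFactory(Factory factory) {        // callee can test the factory
  if (factory) std::printf("%d\n", factory.create(41));
}

int main() {
  // Old call sites passed the lambda directly; new ones wrap it in Factory{...}.
  TakesFactory(Factory{[](int step_id) { return step_id + 1; }});  // prints 42
  return 0;
}

Brace-initializing the aggregate keeps the call sites nearly unchanged while presumably letting the runtime check `if (factory)` before using it.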
|
||||
|
||||
|
442
tensorflow/core/common_runtime/graph_view.cc
Normal file
@ -0,0 +1,442 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/graph_view.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <deque>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/graph/edgeset.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
string NodeItem::DebugString() const {
|
||||
string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
|
||||
if (is_source) {
|
||||
strings::StrAppend(&ret, " source}");
|
||||
} else {
|
||||
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}");
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
GraphView::~GraphView() {
|
||||
static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
|
||||
"Update code if AllocatorAttributes gains a destructor");
|
||||
static_assert(std::is_trivially_destructible<EdgeInfo>::value,
|
||||
"Update code if EdgeInfo gains a destructor");
|
||||
for (int i = 0; i < num_nodes_; i++) {
|
||||
NodeItem* n = node(i);
|
||||
if (n != nullptr) {
|
||||
n->NodeItem::~NodeItem();
|
||||
// Memory for "n" itself is held in space_ & gets cleaned up below
|
||||
}
|
||||
}
|
||||
delete[] node_offsets_;
|
||||
delete[] space_;
|
||||
}
|
||||
|
||||
namespace {
|
||||
typedef std::tuple<int32, int32> OutputAndControlEdges;
|
||||
|
||||
OutputAndControlEdges CountOutputEdges(const Node* n) {
|
||||
DCHECK_LE(n->out_edges().size(), kint32max);
|
||||
int32 num_output_edges = 0;
|
||||
int32 num_output_control_edges = 0;
|
||||
for (auto e : n->out_edges()) {
|
||||
if (IsSink(e->dst())) continue;
|
||||
if (e->IsControlEdge()) {
|
||||
++num_output_control_edges;
|
||||
} else {
|
||||
++num_output_edges;
|
||||
}
|
||||
}
|
||||
return OutputAndControlEdges(num_output_edges, num_output_control_edges);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
size_t GraphView::NodeItemBytes(const Node* n) {
|
||||
int32 num_output_edges;
|
||||
int32 num_output_control_edges;
|
||||
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
// Compute number of bytes needed for NodeItem and variable length data.
|
||||
// We do not subtract sizeof(var) since num_inputs/num_outputs might
|
||||
// both be zero.
|
||||
const size_t raw_bytes =
|
||||
sizeof(NodeItem) // Fixed
|
||||
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
|
||||
+ num_output_control_edges * //
|
||||
sizeof(ControlEdgeInfo) // output_control_edges[...]
|
||||
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
|
||||
+ num_outputs * sizeof(int) // forward_from[num_outputs]
|
||||
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
|
||||
+ num_outputs * sizeof(uint8); // output_type[num_outputs]
|
||||
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
|
||||
static_assert(kItemAlignment % alignof(NodeItem) == 0,
|
||||
"NodeItem must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
|
||||
"EdgeInfo must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0,
|
||||
"ControlEdgeInfo must be aligned with kItemAlignment");
|
||||
static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
|
||||
"AllocatorAttributes must be aligned with kItemAlignment");
|
||||
static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
|
||||
"NodeItem must be aligned with EdgeInfo");
|
||||
static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
|
||||
"NodeItem must be aligned with AllocatorAttributes");
|
||||
static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
|
||||
"EdgeInfo must be aligned with AllocatorAttributes");
|
||||
const size_t bytes =
|
||||
((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
|
||||
return bytes;
|
||||
}
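The final lines of NodeItemBytes() round the raw byte count up to a multiple of kItemAlignment so that consecutive NodeItems stay pointer-aligned inside space_. A self-contained sketch of that round-up arithmetic (RoundUp is a hypothetical helper, not part of this file):

#include <cstddef>
#include <cstdio>

// Round raw_bytes up to the next multiple of alignment, mirroring the last
// statement of GraphView::NodeItemBytes().
static size_t RoundUp(size_t raw_bytes, size_t alignment) {
  return ((raw_bytes + alignment - 1) / alignment) * alignment;
}

int main() {
  // e.g. a 91-byte NodeItem payload padded to sizeof(NodeItem*) == 8 on LP64.
  std::printf("%zu\n", RoundUp(91, 8));   // prints 96
  std::printf("%zu\n", RoundUp(96, 8));   // already aligned: prints 96
  return 0;
}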
|
||||
|
||||
char* GraphView::InitializeNode(char* ptr, const Node* n) {
|
||||
const int id = n->id();
|
||||
CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor
|
||||
|
||||
const size_t bytes = NodeItemBytes(n);
|
||||
constexpr size_t kItemAlignment = sizeof(NodeItem*);
|
||||
CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
|
||||
NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
|
||||
|
||||
// We store a 32-bit offset relative to the beginning of space_, so that we
|
||||
// only need an array of 32-bit values to map from node id to the NodeItem*,
|
||||
// (versus 64 bits on most machines if we just stored an array of NodeItem*
|
||||
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
|
||||
// values as "int" vs "size_t" in CHECK_LE.
|
||||
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
|
||||
const uint32 offset = static_cast<uint32>(ptr - space_);
|
||||
node_offsets_[id] = offset;
|
||||
ptr += bytes;
|
||||
|
||||
int32 num_output_edges;
|
||||
int32 num_output_control_edges;
|
||||
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
|
||||
const int num_inputs = n->num_inputs();
|
||||
const int num_outputs = n->num_outputs();
|
||||
|
||||
new (item) NodeItem();
|
||||
item->num_inputs = num_inputs;
|
||||
item->num_outputs = num_outputs;
|
||||
item->num_output_edges = num_output_edges;
|
||||
item->num_output_control_edges = num_output_control_edges;
|
||||
|
||||
// Fill output edges.
|
||||
// Keep track of the last EdgeInfo in the EdgeInfo array that references
|
||||
// a given output slot. For all but the last, we need to do a copy of the
|
||||
// Tensor when propagating results downstream in the graph, but for the
|
||||
// last one, we can just do a move of the Tensor object to propagate it.
|
||||
gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
|
||||
EdgeInfo* dst_edge = item->output_edge_base();
|
||||
for (auto e : n->out_edges()) {
|
||||
if (e->IsControlEdge()) continue;
|
||||
dst_edge->dst_id = e->dst()->id();
|
||||
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
|
||||
dst_edge->output_slot = e->src_output();
|
||||
dst_edge->is_last = false;
|
||||
const int output_slot = dst_edge->output_slot;
|
||||
if (output_slot >= 0) {
|
||||
last_indices[output_slot] = dst_edge;
|
||||
}
|
||||
// NOTE: The `input_slot` will be rewritten to the frame-wide offset later
|
||||
// in `ExecutorImpl::Initialize()`.
|
||||
dst_edge->input_slot = e->dst_input();
|
||||
dst_edge++;
|
||||
}
|
||||
for (EdgeInfo* edge_info : last_indices) {
|
||||
if (edge_info != nullptr) {
|
||||
edge_info->is_last = true;
|
||||
}
|
||||
}
|
||||
ControlEdgeInfo* dst_control_edge = item->output_control_edge_base();
|
||||
for (auto e : n->out_edges()) {
|
||||
if (!e->IsControlEdge() || IsSink(e->dst())) continue;
|
||||
dst_control_edge->dst_id = e->dst()->id();
|
||||
dst_control_edge++;
|
||||
}
|
||||
|
||||
AllocatorAttributes* output_attrs = item->output_attr_base();
|
||||
for (int i = 0; i < num_outputs; i++) {
|
||||
new (&output_attrs[i]) AllocatorAttributes();
|
||||
}
|
||||
|
||||
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
|
||||
uint8* input_types = item->input_type_base();
|
||||
for (int i = 0; i < num_inputs; i++) {
|
||||
input_types[i] = static_cast<uint8>(n->input_type(i));
|
||||
DCHECK_EQ(item->input_type(i), n->input_type(i));
|
||||
}
|
||||
|
||||
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
|
||||
{
|
||||
std::vector<int> forward_input;
|
||||
Status fwd_status =
|
||||
GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
|
||||
std::vector<int> scoped_allocator_attrs;
|
||||
Status sa_status =
|
||||
GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
|
||||
|
||||
int* forward_from = item->forward_from_base();
|
||||
uint8* output_types = item->output_type_base();
|
||||
for (int i = 0; i < num_outputs; ++i) {
|
||||
output_types[i] = static_cast<uint8>(n->output_type(i));
|
||||
DCHECK_EQ(item->output_type(i), n->output_type(i));
|
||||
|
||||
forward_from[i] = OpKernelContext::Params::kNoReservation;
|
||||
if (sa_status.ok()) {
|
||||
for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
|
||||
if (scoped_allocator_attrs[j] == i) {
|
||||
// This output slot must be explicitly allocated from a
|
||||
// ScopedAllocator.
|
||||
forward_from[i] = OpKernelContext::Params::kNeverForward;
|
||||
DCHECK_EQ(output_attrs[i].scope_id, 0);
|
||||
output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
if (fwd_status.ok() &&
|
||||
forward_from[i] == OpKernelContext::Params::kNoReservation) {
|
||||
DCHECK_EQ(forward_input.size() % 2, 0);
|
||||
for (int j = 0; j < forward_input.size(); j += 2) {
|
||||
if (forward_input[j + 1] == i) {
|
||||
DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
|
||||
forward_from[i] = forward_input[j];
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ptr;
|
||||
}
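The edge-filling loop above remembers the last EdgeInfo written for each output slot and marks it is_last, so the executor can move the Tensor along that edge and copy it along all earlier ones. A tiny standalone sketch of the same two-pass marking, with a hypothetical Edge type:

#include <cstdio>
#include <vector>

// Hypothetical Edge type; the real code operates on EdgeInfo entries laid out
// in the NodeItem's variable-length section.
struct Edge {
  int output_slot;
  bool is_last;
};

int main() {
  // Two output slots, several consumers each; only the last edge per slot may move.
  std::vector<Edge> edges = {{0, false}, {1, false}, {0, false}, {1, false}};
  std::vector<Edge*> last_for_slot(2, nullptr);
  for (Edge& e : edges) last_for_slot[e.output_slot] = &e;  // pass 1: remember last
  for (Edge* e : last_for_slot)
    if (e != nullptr) e->is_last = true;                    // pass 2: mark it
  for (const Edge& e : edges)
    std::printf("slot %d is_last=%d\n", e.output_slot, e.is_last ? 1 : 0);
  return 0;
}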
|
||||
|
||||
Status GraphView::Initialize(const Graph* g) {
|
||||
CHECK(node_offsets_ == nullptr);
|
||||
const int num_nodes = g->num_node_ids();
|
||||
num_nodes_ = num_nodes;
|
||||
size_t total_bytes = 0;
|
||||
for (const Node* n : g->nodes()) {
|
||||
if (n->out_edges().size() > kint32max) {
|
||||
return errors::InvalidArgument(
|
||||
"The executor cannot handle nodes with more than ", kint32max,
|
||||
" output edges. Node ", n->name(), " had ", n->out_edges().size(),
|
||||
" output edges.");
|
||||
}
|
||||
total_bytes += NodeItemBytes(n);
|
||||
}
|
||||
|
||||
node_offsets_ = new uint32[num_nodes];
|
||||
for (int i = 0; i < num_nodes; i++) {
|
||||
node_offsets_[i] = kuint32max;
|
||||
}
|
||||
|
||||
space_ = new char[total_bytes]; // NodeItem objects are allocated here
|
||||
char* ptr = space_;
|
||||
for (const Node* n : g->nodes()) {
|
||||
ptr = InitializeNode(ptr, n);
|
||||
}
|
||||
CHECK_EQ(ptr, space_ + total_bytes);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
namespace {
|
||||
// If a Node has been marked to use a ScopedAllocator x for output i, then
|
||||
// sc_attr will contain the subsequence (i, x) at an even offset. This function
|
||||
// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
|
||||
// only allow one ScopedAllocator use per Node.
|
||||
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
|
||||
int output_index,
|
||||
AllocatorAttributes* alloc_attr) {
|
||||
DCHECK_LE(2, sc_attr.size());
|
||||
for (int i = 0; i < sc_attr.size(); i += 2) {
|
||||
if (sc_attr[i] == output_index) {
|
||||
CHECK_EQ(alloc_attr->scope_id, 0);
|
||||
alloc_attr->scope_id = sc_attr[i + 1];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
} // namespace
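As the comment above notes, the "_scoped_allocator" attribute is a flat list of (output_index, scope_id) pairs. A minimal sketch of that encoding and lookup (LookupScopeId is hypothetical; ExtractScopedAllocatorAttr performs the same scan but writes into AllocatorAttributes::scope_id):

#include <cstdio>
#include <vector>

// Scan the flat (output_index, scope_id) pair list for a given output index.
static int LookupScopeId(const std::vector<int>& pairs, int output_index) {
  for (size_t i = 0; i + 1 < pairs.size(); i += 2) {
    if (pairs[i] == output_index) return pairs[i + 1];
  }
  return 0;  // 0 means "no ScopedAllocator assigned to this output"
}

int main() {
  std::vector<int> attr = {0, 17, 2, 42};  // outputs 0 and 2 use scopes 17 and 42
  std::printf("%d %d %d\n", LookupScopeId(attr, 0), LookupScopeId(attr, 1),
              LookupScopeId(attr, 2));     // prints: 17 0 42
  return 0;
}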
|
||||
|
||||
void GraphView::SetScopedAllocatorAttrs(
|
||||
const std::vector<const Node*>& sa_nodes) {
|
||||
for (const Node* sa : sa_nodes) {
|
||||
NodeItem* sa_item = node(sa->id());
|
||||
AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
|
||||
// Control edges out of the ScopedAllocator should be use instances, but may
|
||||
// include a few other nodes.
|
||||
for (const auto& e : sa->out_edges()) {
|
||||
if (IsSink(e->dst()) || !e->IsControlEdge()) {
|
||||
continue;
|
||||
}
|
||||
Node* use_node = e->dst();
|
||||
NodeItem* item = node(use_node->id());
|
||||
AllocatorAttributes* use_attrs = item->output_attr_base();
|
||||
std::vector<int> scoped_allocator_attrs;
|
||||
Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
|
||||
&scoped_allocator_attrs);
|
||||
if (!s.ok()) {
|
||||
VLOG(2) << "Failed to find expected ScopedAllocator attr on "
|
||||
<< use_node->name();
|
||||
continue;
|
||||
}
|
||||
// There can be more than one output using ScopedAllocation, but this
|
||||
// analysis assumes they use the same ScopedAllocator.
|
||||
for (const auto& e : use_node->out_edges()) {
|
||||
if (IsSink(e->dst()) || !e->IsControlEdge()) {
|
||||
AllocatorAttributes attr;
|
||||
if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
|
||||
e->src_output(), &attr)) {
|
||||
// Set the scope_id on this use instance node.
|
||||
(use_attrs + e->src_output())->Merge(attr);
|
||||
// Propagate the other attributes of this node back to the SA node.
|
||||
attr = *(use_attrs + e->src_output());
|
||||
attr.scope_id = 0;
|
||||
sa_attrs->Merge(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
Status InferAllocAttr(const Node* n, const Node* dst,
|
||||
const DeviceNameUtils::ParsedName& local_dev_name,
|
||||
AllocatorAttributes* attr) {
|
||||
Status s;
|
||||
// Note that it's possible for *n to be a Recv and *dst to be a Send,
|
||||
// so these two cases are not mutually exclusive.
|
||||
if (IsRecv(n)) {
|
||||
string src_name;
|
||||
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_src_name;
|
||||
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
|
||||
s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
|
||||
n->name());
|
||||
return s;
|
||||
}
|
||||
if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
|
||||
// Value is going to be the sink of an RPC.
|
||||
attr->set_nic_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
|
||||
} else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
|
||||
parsed_src_name.type != "CPU") {
|
||||
// Value is going to be the sink of a local DMA from GPU to CPU (or
|
||||
// other types of accelerators).
|
||||
attr->set_gpu_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
|
||||
} else {
|
||||
VLOG(2) << "default alloc case local type " << local_dev_name.type
|
||||
<< " remote type " << parsed_src_name.type;
|
||||
}
|
||||
}
|
||||
if (IsSend(dst)) {
|
||||
string dst_name;
|
||||
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
|
||||
if (!s.ok()) return s;
|
||||
DeviceNameUtils::ParsedName parsed_dst_name;
|
||||
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
|
||||
s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
|
||||
n->name());
|
||||
return s;
|
||||
}
|
||||
if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
|
||||
// Value is going to be the source of an RPC.
|
||||
attr->set_nic_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the source of an RPC out";
|
||||
} else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
|
||||
parsed_dst_name.type != "CPU") {
|
||||
// Value is going to be the source of a local DMA from CPU to GPU (or
|
||||
// other types of accelerators).
|
||||
// Note that this does not cover the case where the allocation of the
|
||||
// output tensor is not generated by the src: n.
|
||||
attr->set_gpu_compatible(true);
|
||||
VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
|
||||
} else {
|
||||
VLOG(2) << "default alloc case local type " << local_dev_name.type
|
||||
<< " remote type " << parsed_dst_name.type;
|
||||
}
|
||||
}
|
||||
if (n->IsCollective()) {
|
||||
// We'll make the sweeping assumption that any collective op is going
|
||||
// to be involved in network i/o.
|
||||
attr->set_nic_compatible(true);
|
||||
}
|
||||
return s;
|
||||
}
|
||||
} // namespace
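InferAllocAttr above boils down to two questions per edge: does the value cross an address space (mark it nic-compatible), or does it cross the host/accelerator boundary within one machine (mark it gpu-compatible)? A toy, self-contained approximation of that decision, not using DeviceNameUtils and ignoring Send, host send/recv, and collectives:

#include <cstdio>
#include <string>

// Hypothetical hint struct; the real code sets flags on AllocatorAttributes.
struct AllocHints {
  bool nic_compatible = false;
  bool gpu_compatible = false;
};

static AllocHints ClassifyRecv(const std::string& sender_task,
                               const std::string& sender_type,
                               const std::string& local_task,
                               const std::string& local_type) {
  AllocHints hints;
  if (sender_task != local_task) {
    hints.nic_compatible = true;   // value arrives over the network (RPC sink)
  } else if (local_type == "CPU" && sender_type != "CPU") {
    hints.gpu_compatible = true;   // value arrives via device-to-host DMA
  }
  return hints;
}

int main() {
  AllocHints h = ClassifyRecv("/task:0", "GPU", "/task:0", "CPU");
  std::printf("nic=%d gpu=%d\n", h.nic_compatible, h.gpu_compatible);  // nic=0 gpu=1
  return 0;
}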
|
||||
|
||||
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
|
||||
Status s;
|
||||
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
|
||||
|
||||
std::vector<const Node*> scoped_allocator_instances;
|
||||
for (const Node* n : g->nodes()) {
|
||||
NodeItem* item = node(n->id());
|
||||
AllocatorAttributes* attrs = item->output_attr_base();
|
||||
if (IsScopedAllocator(n)) {
|
||||
scoped_allocator_instances.push_back(n);
|
||||
}
|
||||
|
||||
// Examine the out edges of each node looking for special use
|
||||
// cases that may affect memory allocation attributes.
|
||||
for (const auto& e : n->out_edges()) {
|
||||
if (!e->IsControlEdge()) {
|
||||
AllocatorAttributes attr;
|
||||
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
|
||||
if (!s.ok()) return s;
|
||||
if (attr.value != 0 || attr.scope_id != 0) {
|
||||
attrs[e->src_output()].Merge(attr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int out = 0; out < n->num_outputs(); out++) {
|
||||
const OpKernel* op_kernel = item->kernel;
|
||||
DCHECK_LT(out, op_kernel->output_memory_types().size());
|
||||
bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
|
||||
if (on_host) {
|
||||
AllocatorAttributes h;
|
||||
h.set_on_host(on_host);
|
||||
attrs[out].Merge(h);
|
||||
}
|
||||
}
|
||||
}
|
||||
SetScopedAllocatorAttrs(scoped_allocator_instances);
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
240
tensorflow/core/common_runtime/graph_view.h
Normal file
@ -0,0 +1,240 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
|
||||
#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class Device;
|
||||
class Graph;
|
||||
class Node;
|
||||
class OpKernel;
|
||||
class Tensor;
|
||||
|
||||
// Represents a single data edge in a `NodeItem`.
|
||||
struct EdgeInfo {
|
||||
// The node ID of the destination in the containing `GraphView`.
|
||||
int dst_id;
|
||||
// The index of the output that produces values on this edge.
|
||||
int output_slot : 31;
|
||||
// true if this is the last info for output_slot in the EdgeInfo list.
|
||||
bool is_last : 1;
|
||||
// The index of the input that consumes values on this edge.
|
||||
int input_slot;
|
||||
};
|
||||
|
||||
// Represents a single control edge in a `NodeItem`.
|
||||
struct ControlEdgeInfo {
|
||||
// The node ID of the destination in the containing `GraphView`.
|
||||
int dst_id;
|
||||
};
|
||||
|
||||
// Compact structure representing a graph node and its associated kernel.
|
||||
//
|
||||
// Each NodeItem is an element of exactly one GraphView.
|
||||
struct NodeItem {
|
||||
// The index of this node's item in its GraphView.
|
||||
int node_id = -1;
|
||||
|
||||
// Cached attributes of this node for fast lookup.
|
||||
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
|
||||
bool is_merge : 1; // True iff IsMerge(node)
|
||||
bool is_enter : 1; // True iff IsEnter(node)
|
||||
bool is_constant_enter : 1; // True iff IsEnter(node) and
|
||||
// node->GetAttr("is_constant") == true.
|
||||
bool is_exit : 1; // True iff IsExit(node)
|
||||
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
|
||||
bool is_source : 1; // True iff IsSource(node)
|
||||
// True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
|
||||
bool is_enter_exit_or_next_iter : 1;
|
||||
bool is_transfer_node : 1; // True iff IsTransferNode(node)
|
||||
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
|
||||
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
|
||||
bool is_next_iteration : 1; // True iff IsNextIteration(node)
|
||||
bool is_noop : 1;  // True iff item->kernel->type_string_view() == "NoOp"
|
||||
bool
|
||||
is_any_consumer_merge_or_control_trigger : 1; // True iff the destination
|
||||
// of any output edge is a
|
||||
// merge or control trigger
|
||||
// node.
|
||||
|
||||
// The kernel for this node.
|
||||
OpKernel* kernel = nullptr;
|
||||
|
||||
// If the kernel is a Const op, this points to the constant tensor.
|
||||
const Tensor* const_tensor = nullptr;
|
||||
|
||||
// Cached values of node->num_inputs() and node->num_outputs(), to
|
||||
// avoid levels of indirection.
|
||||
int num_inputs;
|
||||
int num_outputs;
|
||||
|
||||
// ExecutorImpl::tensors_[input_start] is the 1st positional input
|
||||
// for this node.
|
||||
int input_start = 0;
|
||||
|
||||
// Number of output edges, excluding control edges.
|
||||
int32 num_output_edges;
|
||||
|
||||
// Number of output control edges.
|
||||
int32 num_output_control_edges;
|
||||
|
||||
// If non-null, contains an array of num_outputs bools, where the ith bool
|
||||
// is true if and only if the ith output is consumed by another node.
|
||||
std::unique_ptr<bool[]> outputs_required;
|
||||
|
||||
gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
|
||||
return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
|
||||
num_output_edges);
|
||||
}
|
||||
|
||||
gtl::ArraySlice<EdgeInfo> output_edges() const {
|
||||
return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
|
||||
}
|
||||
|
||||
gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
|
||||
return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
|
||||
num_output_control_edges);
|
||||
}
|
||||
|
||||
DataType input_type(int i) const {
|
||||
DCHECK_LT(i, num_inputs);
|
||||
return static_cast<DataType>(input_type_base()[i]);
|
||||
}
|
||||
DataType output_type(int i) const {
|
||||
DCHECK_LT(i, num_outputs);
|
||||
return static_cast<DataType>(output_type_base()[i]);
|
||||
}
|
||||
|
||||
// Return array of per-output allocator attributes.
|
||||
const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
|
||||
|
||||
// Return array of expected input index from which each output should
|
||||
// be forwarded:
|
||||
// kNeverForward (-2) for DO NOT FORWARD (must allocate).
|
||||
// kNoReservation (-1) for no expected forwarding.
|
||||
// 0... for forward from that input.
|
||||
const int* forward_from() const { return forward_from_base(); }
|
||||
|
||||
string DebugString() const;
|
||||
|
||||
private:
|
||||
friend class GraphView;
|
||||
|
||||
NodeItem() {}
|
||||
|
||||
// Variable length section starts immediately after *this
|
||||
// (uint8 is enough for DataType).
|
||||
// EdgeInfo out_edges[num_output_edges];
|
||||
// ControlEdgeInfo out_control_edges[num_output_control_edges];
|
||||
// AllocatorAttributes output_attr[num_outputs];
|
||||
// int forward_from[num_outputs];
|
||||
// uint8 input_type[num_inputs];
|
||||
// uint8 output_type[num_outputs];
|
||||
|
||||
// Return pointer to variable length section.
|
||||
char* var() const {
|
||||
return const_cast<char*>(reinterpret_cast<const char*>(this) +
|
||||
sizeof(NodeItem));
|
||||
}
|
||||
|
||||
EdgeInfo* output_edge_base() const {
|
||||
return reinterpret_cast<EdgeInfo*>(var());
|
||||
}
|
||||
|
||||
ControlEdgeInfo* output_control_edge_base() const {
|
||||
return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
|
||||
num_output_edges);
|
||||
}
|
||||
|
||||
AllocatorAttributes* output_attr_base() const {
|
||||
return reinterpret_cast<AllocatorAttributes*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges);
|
||||
}
|
||||
int* forward_from_base() const {
|
||||
return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) *
|
||||
num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs);
|
||||
}
|
||||
uint8* input_type_base() const {
|
||||
return reinterpret_cast<uint8*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
|
||||
}
|
||||
uint8* output_type_base() const {
|
||||
return reinterpret_cast<uint8*>(
|
||||
var() + sizeof(EdgeInfo) * num_output_edges +
|
||||
sizeof(ControlEdgeInfo) * num_output_control_edges +
|
||||
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
|
||||
sizeof(uint8) * num_inputs);
|
||||
}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
|
||||
};
|
||||
|
||||
// Immutable view of a Graph organized for efficient execution.
|
||||
//
|
||||
// TODO(b/152651962): Add independent unit tests for this class.
|
||||
class GraphView {
|
||||
public:
|
||||
GraphView() : space_(nullptr) {}
|
||||
~GraphView();
|
||||
|
||||
Status Initialize(const Graph* g);
|
||||
Status SetAllocAttrs(const Graph* g, const Device* device);
|
||||
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
|
||||
|
||||
NodeItem* node(int32 id) const {
|
||||
DCHECK_GE(id, 0);
|
||||
DCHECK_LT(id, num_nodes_);
|
||||
uint32 offset = node_offsets_[id];
|
||||
return ((offset == kuint32max)
|
||||
? nullptr
|
||||
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
|
||||
}
|
||||
|
||||
int32 num_nodes() const { return num_nodes_; }
|
||||
|
||||
private:
|
||||
char* InitializeNode(char* ptr, const Node* n);
|
||||
size_t NodeItemBytes(const Node* n);
|
||||
|
||||
int32 num_nodes_ = 0;
|
||||
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
|
||||
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
|
||||
|
||||
char* space_; // NodeItem objects are allocated here
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
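graph_view.h packs all NodeItems into a single byte arena (space_) and addresses them through 32-bit offsets (node_offsets_) instead of 64-bit pointers, as the comments in InitializeNode() explain. A minimal standalone sketch of that offset-arena pattern, with a hypothetical fixed-size Item:

#include <cstdint>
#include <cstdio>
#include <new>
#include <vector>

// Hypothetical fixed-size Item; real NodeItems also carry a variable-length
// tail, which is why the real code computes per-node byte sizes first.
struct Item {
  int id;
};

int main() {
  const int num_items = 3;
  char* space = new char[num_items * sizeof(Item)];   // one arena, like space_
  std::vector<uint32_t> offsets(num_items);           // 32-bit offsets, not pointers
  for (int i = 0; i < num_items; ++i) {
    offsets[i] = static_cast<uint32_t>(i * sizeof(Item));
    new (space + offsets[i]) Item{i};                 // placement-new, like the real code
  }
  Item* item = reinterpret_cast<Item*>(space + offsets[2]);
  std::printf("%d\n", item->id);                      // prints 2
  delete[] space;                                     // Item is trivially destructible
  return 0;
}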
|
@ -45,13 +45,13 @@ auto* graph_run_time_usecs_histogram = monitoring::Sampler<0>::New(
|
||||
auto* graph_run_input_tensor_bytes = monitoring::Sampler<0>::New(
|
||||
{"/tensorflow/core/graph_run_input_tensor_bytes",
|
||||
"The size of input tensors in bytes."},
|
||||
// Power of 2 with bucket count 14 (256G)
|
||||
{monitoring::Buckets::Exponential(1, 4, 20)});
|
||||
// Power of 2 with bucket count 14 (256MB)
|
||||
{monitoring::Buckets::Exponential(1, 4, 14)});
|
||||
|
||||
auto* graph_run_output_tensor_bytes = monitoring::Sampler<0>::New(
|
||||
{"/tensorflow/core/graph_run_output_tensor_bytes",
|
||||
"The size of output tensors in bytes."},
|
||||
// Power of 2 with bucket count 14 (256G)
|
||||
// Power of 2 with bucket count 14 (256MB)
|
||||
{monitoring::Buckets::Exponential(1, 4, 14)});
|
||||
|
||||
auto* graph_unused_outputs = monitoring::Counter<1>::New(
|
||||
@ -72,8 +72,8 @@ auto* tf_data_bytes_fetched_counter = monitoring::Counter<0>::New(
|
||||
auto* tf_data_getnext_duration_counter = monitoring::Sampler<0>::New(
|
||||
{"/tensorflow/data/getnext_duration",
|
||||
"Microseconds spent fetching an element from tf.data Dataset iterator."},
|
||||
// Power of 2 with bucket count 14 (256G)
|
||||
{monitoring::Buckets::Exponential(1, 4, 20)});
|
||||
// Power of 2 with bucket count 10 (1024 ms)
|
||||
{monitoring::Buckets::Exponential(1, 2, 10)});
|
||||
|
||||
auto* tf_data_elements_counter = monitoring::Counter<1>::New(
|
||||
"/tensorflow/data/elements", "tf.data elements", "name");
|
||||
|
@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
||||
session_metadata_, this);
|
||||
}
|
||||
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||
all_devices = parent_->remote_device_mgr();
|
||||
}
|
||||
|
||||
for (auto d : all_devices->ListDevices()) {
|
||||
device_set_.AddDevice(d);
|
||||
}
|
||||
InitializeDeviceSet();
|
||||
}
|
||||
|
||||
/* static */
|
||||
@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||
"function executions");
|
||||
}
|
||||
|
||||
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
|
||||
DeviceMgr const* all_devices = device_mgr_;
|
||||
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
|
||||
all_devices = parent_->remote_device_mgr();
|
||||
}
|
||||
|
||||
device_set_.reset(new DeviceSet);
|
||||
for (auto d : all_devices->ListDevices()) {
|
||||
device_set_->AddDevice(d);
|
||||
}
|
||||
}
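InitializeDeviceSet() above rebuilds the device set by resetting a smart pointer, which is what lets the set be refreshed after construction (the constructor now just calls it). A small standalone sketch of that refactor with hypothetical DeviceSetStub/RuntimeStub stand-ins:

#include <cstdio>
#include <memory>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-ins for DeviceSet and ProcessFunctionLibraryRuntime;
// only the pointer-reset pattern is shown.
struct DeviceSetStub {
  std::set<std::string> names;
  void AddDevice(const std::string& name) { names.insert(name); }
};

struct RuntimeStub {
  std::unique_ptr<DeviceSetStub> device_set_;
  void InitializeDeviceSet(const std::vector<std::string>& devices) {
    device_set_.reset(new DeviceSetStub);            // discard any previous set
    for (const auto& d : devices) device_set_->AddDevice(d);
  }
  const DeviceSetStub* device_set() const { return device_set_.get(); }
};

int main() {
  RuntimeStub runtime;
  runtime.InitializeDeviceSet({"/cpu:0"});
  runtime.InitializeDeviceSet({"/cpu:0", "/gpu:0"});  // can be refreshed later
  std::printf("%zu devices\n", runtime.device_set()->names.size());  // prints 2
  return 0;
}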
|
||||
|
||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||
const string& device_name) const {
|
||||
Device* device = nullptr;
|
||||
@ -225,7 +230,8 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||
}
|
||||
const auto& iter = flr_map_->find(device);
|
||||
if (iter == flr_map_->end()) {
|
||||
LOG(ERROR) << "Could not find device: " << device_name;
|
||||
VLOG(1) << "Could not find device: " << device_name
|
||||
<< "in the local process.";
|
||||
return nullptr;
|
||||
}
|
||||
return iter->second.get();
|
||||
@ -678,7 +684,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
TF_RETURN_IF_ERROR(
|
||||
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
|
||||
TF_RETURN_IF_ERROR(PinArgsAndRets(
|
||||
options.input_devices, options.output_devices, device_set_, arg_nodes,
|
||||
options.input_devices, options.output_devices, *device_set_, arg_nodes,
|
||||
ret_nodes,
|
||||
options.config_proto.allow_soft_placement() ? default_device : nullptr));
|
||||
|
||||
@ -691,7 +697,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
|
||||
bool control_rets_updated = false;
|
||||
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
|
||||
device_set_, options.config_proto, &graph, &data->lib_def_,
|
||||
*device_set_, options.config_proto, &graph, &data->lib_def_,
|
||||
&control_ret_node_names, &control_rets_updated));
|
||||
|
||||
if (control_rets_updated) {
|
||||
@ -714,7 +720,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
optimization_options.session_options = &session_options;
|
||||
optimization_options.graph = &graph;
|
||||
optimization_options.flib_def = &data->lib_def_;
|
||||
optimization_options.device_set = &device_set_;
|
||||
optimization_options.device_set = device_set_.get();
|
||||
optimization_options.is_function_graph = true;
|
||||
|
||||
DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
|
||||
@ -725,7 +731,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
// exceptions/warnings in case where nested function call options are ignored.
|
||||
DumpGraph("Before calling Placer", graph.get());
|
||||
Placer placer(graph.get(), function_name, optimization_options.flib_def,
|
||||
&device_set_, default_device,
|
||||
device_set_.get(), default_device,
|
||||
options.config_proto.allow_soft_placement(),
|
||||
options.config_proto.log_device_placement());
|
||||
TF_RETURN_IF_ERROR(placer.Run());
|
||||
@ -741,7 +747,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
DumpGraph("Before running graph optimization fn", graph.get());
|
||||
Status status = options.optimize_graph_fn(
|
||||
std::move(ret_node_names), std::move(control_ret_node_names),
|
||||
&data->lib_def_, device_set_, cpu_device, &graph);
|
||||
&data->lib_def_, *device_set_, cpu_device, &graph);
|
||||
if (!status.ok()) {
|
||||
LOG(WARNING) << "Ignoring multi-device function optimization failure: "
|
||||
<< status.ToString();
|
||||
@ -765,7 +771,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
|
||||
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
|
||||
TF_RETURN_IF_ERROR(
|
||||
PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
|
||||
PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
|
||||
|
||||
for (const auto& pair : subgraphs) {
|
||||
DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
|
||||
@ -841,7 +847,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
|
||||
const string& target = pair.first;
|
||||
|
||||
const string& device_type =
|
||||
device_set_.FindDeviceByName(target)->device_type();
|
||||
device_set_->FindDeviceByName(target)->device_type();
|
||||
Graph* subgraph = pair.second.get();
|
||||
|
||||
status->Update(UpdateArgAndRetvalMetadata(
|
||||
@ -1258,12 +1264,18 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::DoneCallback
ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
    std::vector<std::unique_ptr<CleanUpItem>>* items,
    FunctionLibraryRuntime::DoneCallback done,
    const Rendezvous* rendezvous) const {
    FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
    const Rendezvous* created_rendezvous) const {
  return
      [this, items, done = std::move(done), rendezvous](const Status& status) {
        if (rendezvous) {
          rendezvous->Unref();
      [this, items, done = std::move(done), step_id,
       created_rendezvous](const Status& status) {
        if (created_rendezvous) {
          DCHECK(rendezvous_factory_);
          created_rendezvous->Unref();
          Status s = rendezvous_factory_.CleanUp(step_id);
          if (!s.ok()) {
            LOG(ERROR) << s;
          }
        }
        auto* local_status = new Status(status);
        CleanUp(items, [local_status, done](const Status& cleanup_status) {
@ -1281,15 +1293,16 @@ void ProcessFunctionLibraryRuntime::Run(
    std::vector<Tensor>* rets,
    FunctionLibraryRuntime::DoneCallback done) const {
  FunctionLibraryRuntime::Options new_opts = opts;
  Rendezvous* rendezvous = nullptr;
  Rendezvous* created_rendezvous = nullptr;
  if (!opts.rendezvous) {
    if (rendezvous_factory_) {
      Status s = rendezvous_factory_(opts.step_id, device_mgr_, &rendezvous);
      Status s =
          rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
      if (!s.ok()) {
        done(s);
        return;
      }
      new_opts.rendezvous = rendezvous;
      new_opts.rendezvous = created_rendezvous;
    } else {
      done(
          errors::FailedPrecondition("The caller does not provide a rendezvous "
@ -1301,7 +1314,8 @@ void ProcessFunctionLibraryRuntime::Run(
  }

  auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done), rendezvous);
  done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
                                    new_opts.step_id, created_rendezvous);
  bool multi_device;
  {
    tf_shared_lock l(mu_);
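The net effect of these two hunks is that a rendezvous the runtime created for the step (as opposed to one supplied by the caller) is now unref'd and its per-step state cleaned up when the step's done callback fires. Below is a minimal, standalone sketch of that wrapping pattern; the Status, Rendezvous, and callback types are simplified stand-ins, not the TensorFlow classes.

#include <cstdint>
#include <functional>
#include <iostream>
#include <utility>

// Simplified stand-ins for the TensorFlow types involved in this change.
struct Status { bool ok = true; };
using DoneCallback = std::function<void(const Status&)>;
using CleanupFn = std::function<void(int64_t step_id)>;

struct Rendezvous {
  void Unref() { std::cout << "rendezvous unref'd\n"; }
};

// Wraps `done` so that, if the runtime created `created_rendezvous` itself,
// it is released and per-step cleanup runs exactly once when the step ends.
DoneCallback WrapDone(DoneCallback done, int64_t step_id,
                      Rendezvous* created_rendezvous, CleanupFn cleanup) {
  return [done = std::move(done), step_id, created_rendezvous,
          cleanup = std::move(cleanup)](const Status& status) {
    if (created_rendezvous != nullptr) {
      created_rendezvous->Unref();
      cleanup(step_id);  // analogous to rendezvous_factory_.CleanUp(step_id)
    }
    done(status);
  };
}

int main() {
  Rendezvous rendez;
  auto done = WrapDone(
      [](const Status&) { std::cout << "step done\n"; },
      /*step_id=*/42, &rendez,
      [](int64_t id) { std::cout << "cleanup for step " << id << "\n"; });
  // Prints: rendezvous unref'd, cleanup for step 42, step done.
  done(Status{});
}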
@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
DistributedFunctionLibraryRuntime* parent = nullptr,
|
||||
const CustomKernelCreator* custom_kernel_creator = nullptr,
|
||||
const SessionMetadata* session_metadata = nullptr,
|
||||
Rendezvous::Factory rendezvous_factory = nullptr);
|
||||
Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
|
||||
|
||||
virtual ~ProcessFunctionLibraryRuntime() {
|
||||
// Deleting the FunctionLibraryRuntime map will delete the function handles
|
||||
@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime {
|
||||
|
||||
const DeviceMgr* device_mgr() { return device_mgr_; }
|
||||
|
||||
const DeviceSet* device_set() { return &device_set_; }
|
||||
const DeviceSet* device_set() { return device_set_.get(); }
|
||||
|
||||
// Initialize the set of local and remote devices for op device selection.
|
||||
void InitializeDeviceSet();
|
||||
|
||||
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
|
||||
|
||||
@ -294,7 +297,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
|
||||
FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
|
||||
std::vector<std::unique_ptr<CleanUpItem>>* items,
|
||||
FunctionLibraryRuntime::DoneCallback done,
|
||||
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
|
||||
const Rendezvous* rendezvous) const;
|
||||
|
||||
DistributedFunctionLibraryRuntime* const parent_;
|
||||
@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime {
|
||||
Env* const env_;
|
||||
const absl::optional<const ConfigProto> config_;
|
||||
const DeviceMgr* const device_mgr_;
|
||||
DeviceSet device_set_;
|
||||
std::unique_ptr<DeviceSet> device_set_;
|
||||
const FunctionLibraryDefinition* lib_def_;
|
||||
thread::ThreadPool* default_thread_pool_;
|
||||
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -122,10 +123,24 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
        TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
        /*thread_pool=*/nullptr, cluster_flr_.get(),
        /*custom_kernel_creator=*/nullptr, session_metadata,
        [](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
          *r = new IntraProcessRendezvous(device_mgr);
          return Status::OK();
        }));
        Rendezvous::Factory{
            [this](const int64 step_id, const DeviceMgr* device_mgr,
                   Rendezvous** r) {
              *r = new IntraProcessRendezvous(device_mgr);
              if (rendezvous_ref_counts_.find(step_id) !=
                  rendezvous_ref_counts_.end()) {
                rendezvous_ref_counts_[step_id]++;
              } else {
                rendezvous_ref_counts_[step_id] = 1;
              }
              return Status::OK();
            },
            [this](const int64 step_id) {
              CHECK(rendezvous_ref_counts_.find(step_id) !=
                    rendezvous_ref_counts_.end());
              rendezvous_ref_counts_[step_id]--;
              return Status::OK();
            }}));
  }
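The test factory above pairs a create function that increments a per-step counter with a cleanup function that decrements it, which is what lets the SingleCallFindDevice test below assert that every rendezvous created for a step was also cleaned up. A standalone sketch of that bookkeeping, using plain STL types rather than the test fixture (names here are illustrative only):

#include <cassert>
#include <cstdint>
#include <unordered_map>

// Tracks how many rendezvous objects are outstanding per step id; a
// hypothetical standalone version of the test's rendezvous_ref_counts_.
class StepRefCounts {
 public:
  void OnCreate(int64_t step_id) { counts_[step_id]++; }
  void OnCleanup(int64_t step_id) {
    assert(counts_.find(step_id) != counts_.end());
    counts_[step_id]--;
  }
  int Outstanding(int64_t step_id) const {
    auto it = counts_.find(step_id);
    return it == counts_.end() ? 0 : it->second;
  }

 private:
  std::unordered_map<int64_t, int> counts_;
};

int main() {
  StepRefCounts counts;
  counts.OnCreate(7);   // the factory's create_fn ran for step 7
  counts.OnCleanup(7);  // the factory's cleanup_fn ran when the step finished
  assert(counts.Outstanding(7) == 0);  // mirrors the EXPECT_EQ(0, ...) check
}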
|
||||
Status Instantiate(
|
||||
@ -289,6 +304,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||
std::unique_ptr<TestClusterFLR> cluster_flr_;
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||
|
||||
// To ensure that we are cleaning up the rendezvous properly.
|
||||
std::unordered_map<int64, int> rendezvous_ref_counts_;
|
||||
};
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
|
||||
@ -362,6 +380,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||
test::ExpectTensorEqual<tstring>(
|
||||
y, test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:0"},
|
||||
TensorShape({})));
|
||||
EXPECT_EQ(1, rendezvous_ref_counts_.size());
|
||||
EXPECT_EQ(opts.step_id, rendezvous_ref_counts_.begin()->first);
|
||||
EXPECT_EQ(0, rendezvous_ref_counts_.begin()->second);
|
||||
}
|
||||
|
||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||
|
@ -61,13 +61,14 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
|
||||
/*thread_pool=*/nullptr, /*parent=*/nullptr,
|
||||
/*custom_kernel_creator=*/nullptr,
|
||||
/*session_metadata=*/nullptr,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
});
|
||||
Rendezvous::Factory{
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}});
|
||||
|
||||
string fetch_node = "";
|
||||
for (auto node : graph_def.node()) {
|
||||
for (const auto& node : graph_def.node()) {
|
||||
if (node.op() == "_Retval") {
|
||||
fetch_node = node.input(0);
|
||||
}
|
||||
|
@ -378,8 +378,8 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
|
||||
return errors::InvalidArgument("Invalid TensorProto: ",
|
||||
input.tensor().DebugString());
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
|
||||
std::move(tensor), nullptr, nullptr, eager_context, &handle));
|
||||
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
|
||||
nullptr, eager_context);
|
||||
op->AddInput(handle);
|
||||
}
|
||||
}
|
||||
@ -558,9 +558,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
|
||||
return errors::InvalidArgument("Unable to parse tensor proto");
|
||||
}
|
||||
|
||||
TensorHandle* tensor_handle = nullptr;
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
|
||||
std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
|
||||
TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
|
||||
std::move(tensor), nullptr, nullptr, eager_context);
|
||||
TensorHandle* copied_handle = nullptr;
|
||||
Device* device;
|
||||
TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(
|
||||
|
@ -162,8 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
|
||||
in.op_device().empty() ? in.device() : in.op_device();
|
||||
TF_RETURN_IF_ERROR(
|
||||
parent_->FindDeviceFromName(device_name.c_str(), &device));
|
||||
TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
|
||||
in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
|
||||
*out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
|
||||
in.dtype(), device, parent_);
|
||||
TensorHandle::ResourceHandleInfo resource_handle_info;
|
||||
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
|
||||
&resource_handle_info.dtypes_and_shapes;
|
||||
|
@ -70,9 +70,8 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
|
||||
RemoteMgr remote_mgr(false, ctx_);
|
||||
Tensor t(DT_FLOAT, TensorShape({0}));
|
||||
|
||||
TensorHandle* handle;
|
||||
TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
|
||||
local_device_, ctx_, &handle));
|
||||
TensorHandle* handle = TensorHandle::CreateLocalHandle(
|
||||
std::move(t), local_device_, local_device_, ctx_);
|
||||
const uint64 op_id = 2;
|
||||
const int output_num = 3;
|
||||
TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
|
||||
@ -91,10 +90,9 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
|
||||
|
||||
const uint64 op_id = 3;
|
||||
const int output_num = 1;
|
||||
TensorHandle* handle;
|
||||
TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
|
||||
TensorHandle* handle = TensorHandle::CreateUnshapedRemoteHandle(
|
||||
op_id, output_num,
|
||||
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
|
||||
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
|
||||
RemoteTensorHandle remote_handle;
|
||||
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
|
||||
handle, &remote_handle, remote_device_, remote_device_->name()));
|
||||
|
@ -67,7 +67,7 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
|
||||
}
|
||||
|
||||
GraphMgr::~GraphMgr() {
|
||||
for (auto p : table_) p.second->Unref();
|
||||
for (const auto& p : table_) p.second->Unref();
|
||||
}
|
||||
|
||||
GraphMgr::Item::~Item() {
|
||||
@ -141,13 +141,18 @@ Status GraphMgr::InitItem(
|
||||
gdef.versions().producer(), item->lib_def.get(),
|
||||
graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
|
||||
/*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
|
||||
[this, session](const int64 step_id, const DeviceMgr*,
|
||||
Rendezvous** r) -> Status {
|
||||
auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
|
||||
TF_RETURN_IF_ERROR(remote_r->Initialize(session));
|
||||
*r = remote_r;
|
||||
return Status::OK();
|
||||
}));
|
||||
Rendezvous::Factory{
|
||||
[this, session](const int64 step_id, const DeviceMgr*,
|
||||
Rendezvous** r) -> Status {
|
||||
auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
|
||||
TF_RETURN_IF_ERROR(remote_r->Initialize(session));
|
||||
*r = remote_r;
|
||||
return Status::OK();
|
||||
},
|
||||
[this](const int64 step_id) {
|
||||
this->worker_env_->rendezvous_mgr->Cleanup(step_id);
|
||||
return Status::OK();
|
||||
}}));
|
||||
|
||||
// Constructs the graph out of "gdef".
|
||||
Graph graph(OpRegistry::Global());
|
||||
|
@ -387,45 +387,28 @@ void OpKernelContext::SetStatus(const Status& status) {
|
||||
}
|
||||
|
||||
Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
name,
|
||||
"' when single-valued input was "
|
||||
"expected");
|
||||
}
|
||||
if (input_is_ref(start)) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(name, &index));
|
||||
if (input_is_ref(index)) {
|
||||
return errors::InvalidArgument("OpKernel used ref input name '", name,
|
||||
"' when non-ref input was expected");
|
||||
}
|
||||
*tensor = (*params_->inputs)[start].tensor;
|
||||
*tensor = (*params_->inputs)[index].tensor;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
name,
|
||||
"' when single-valued input was "
|
||||
"expected");
|
||||
}
|
||||
const TensorValue& value((*params_->inputs)[start]);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(name, &index));
|
||||
const TensorValue& value((*params_->inputs)[index]);
|
||||
*dtype = value.dtype();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
name,
|
||||
"' when single-valued input was expected");
|
||||
}
|
||||
*out_mutex = input_ref_mutex(start);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(name, &index));
|
||||
*out_mutex = input_ref_mutex(index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -497,23 +480,9 @@ bool OpKernelContext::forward_input_to_output_with_shape(
|
||||
Status OpKernelContext::forward_input_to_output_with_shape(
|
||||
StringPiece input_name, StringPiece output_name,
|
||||
const TensorShape& output_shape, Tensor** output) {
|
||||
int input_index, output_index, stop;
|
||||
TF_RETURN_IF_ERROR(
|
||||
params_->op_kernel->InputRange(input_name, &input_index, &stop));
|
||||
if (stop != input_index + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
input_name,
|
||||
"' when single-valued input was "
|
||||
"expected");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
params_->op_kernel->OutputRange(output_name, &output_index, &stop));
|
||||
if (stop != output_index + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued output name '",
|
||||
output_name,
|
||||
"' when single-valued output was "
|
||||
"expected");
|
||||
}
|
||||
int input_index, output_index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
|
||||
TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
|
||||
if (!forward_input_to_output_with_shape(input_index, output_index,
|
||||
output_shape, output)) {
|
||||
return errors::FailedPrecondition("OpKernel could not forward input '",
|
||||
@ -621,23 +590,18 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) {
|
||||
|
||||
Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
|
||||
bool lock_held) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
name,
|
||||
"' when single-valued input was expected");
|
||||
}
|
||||
if (!input_is_ref(start)) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(name, &index));
|
||||
if (!input_is_ref(index)) {
|
||||
return errors::InvalidArgument("OpKernel used non-ref input name '", name,
|
||||
"' when ref input was expected");
|
||||
}
|
||||
// return a copy of the Ref acquired while holding the mutex
|
||||
if (lock_held) {
|
||||
*tensor = *(*params_->inputs)[start].tensor;
|
||||
*tensor = *(*params_->inputs)[index].tensor;
|
||||
} else {
|
||||
tf_shared_lock l(*input_ref_mutex(start));
|
||||
*tensor = *(*params_->inputs)[start].tensor;
|
||||
tf_shared_lock l(*input_ref_mutex(index));
|
||||
*tensor = *(*params_->inputs)[index].tensor;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -645,18 +609,13 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
|
||||
Status OpKernelContext::replace_ref_input(StringPiece name,
|
||||
const Tensor& tensor,
|
||||
bool lock_held) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued input name '",
|
||||
name,
|
||||
"' when single-valued input was expected");
|
||||
}
|
||||
if (!input_is_ref(start)) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_input_index(name, &index));
|
||||
if (!input_is_ref(index)) {
|
||||
return errors::InvalidArgument("OpKernel used immutable input name '", name,
|
||||
"' when ref input was expected");
|
||||
}
|
||||
replace_ref_input(start, tensor, lock_held);
|
||||
replace_ref_input(index, tensor, lock_held);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -888,7 +847,22 @@ Status OpKernelContext::allocate_persistent(DataType type,
  return s;
}

Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
Status OpKernelContext::get_input_index(StringPiece name,
                                        int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
  if (stop != start + 1) {
    return errors::InvalidArgument("OpKernel used list-valued input name '",
                                   name,
                                   "' when single-valued input was "
                                   "expected");
  }
  *out_index = start;
  return Status::OK();
}

Status OpKernelContext::get_output_index(StringPiece name,
                                         int* out_index) const {
  int start, stop;
  TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
  if (stop != start + 1) {
@ -897,22 +871,31 @@ Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
                                   "' when single-valued output was "
                                   "expected");
  }
  set_output(start, tensor);
  *out_index = start;
  return Status::OK();
}
|
||||
void OpKernelContext::set_output(int index, const Tensor& tensor) {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, outputs_.size());
|
||||
const DataType type = params_->op_kernel->output_type(index);
|
||||
CHECK(!IsRefType(type));
|
||||
CHECK_EQ(mutable_output(index), nullptr);
|
||||
Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_output_index(name, &index));
|
||||
set_output(index, tensor);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_output_index(name, &index));
|
||||
set_output(index, std::move(tensor));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
|
||||
int index, const Tensor& tensor) {
|
||||
bool allocate_and_copy = false;
|
||||
const bool never_forward =
|
||||
(params_->forward_from_array != nullptr &&
|
||||
params_->forward_from_array[index] == Params::kNeverForward);
|
||||
if (never_forward) {
|
||||
if (TF_PREDICT_FALSE(never_forward)) {
|
||||
maybe_initialize_scope_id_set();
|
||||
if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
|
||||
allocated_scope_ids_->end()) {
|
||||
@ -929,7 +912,7 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
if (allocate_and_copy) {
|
||||
if (TF_PREDICT_FALSE(allocate_and_copy)) {
|
||||
// This output was marked to not be forwarded either during graph
|
||||
// construction or grappler passes. Force an allocation and copy input to
|
||||
// output.
|
||||
@ -939,31 +922,59 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
|
||||
<< params_->forward_from_array[index] << " alloc_attr.scope_id "
|
||||
<< output_alloc_attr(index).scope_id;
|
||||
auto new_tensor = MakeUnique<Tensor>();
|
||||
Status s = allocate_tensor(type, tensor.shape(), new_tensor.get(),
|
||||
Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
|
||||
output_alloc_attr(index));
|
||||
TF_CHECK_OK(s);
|
||||
device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
|
||||
op_device_context(), [](const Status&) {});
|
||||
outputs_[index] = TensorValue(new_tensor.release());
|
||||
} else {
|
||||
}
|
||||
return allocate_and_copy;
|
||||
}
|
||||
|
||||
void OpKernelContext::maybe_track_allocations_for_set_output(
|
||||
const Tensor& tensor) {
|
||||
if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
|
||||
DCHECK(tracking_state_);
|
||||
mutex_lock l(tracking_state_->stats_mu);
|
||||
const auto it = std::find_if(
|
||||
tracking_state_->temp_tensor_buffer_and_size.begin(),
|
||||
tracking_state_->temp_tensor_buffer_and_size.end(),
|
||||
[&tensor](const std::pair<const void*, int64>& e) {
|
||||
return e.first ==
|
||||
static_cast<const void*>(tensor.tensor_data().data());
|
||||
});
|
||||
if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
|
||||
tracking_state_->temp_memory_allocated -= it->second;
|
||||
tracking_state_->temp_tensor_buffer_and_size.erase(it);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OpKernelContext::set_output(int index, const Tensor& tensor) {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, outputs_.size());
|
||||
const DataType type = params_->op_kernel->output_type(index);
|
||||
CHECK(!IsRefType(type));
|
||||
CHECK_EQ(outputs_[index].tensor, nullptr);
|
||||
if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
|
||||
// Input can be forwarded to output; incref on `tensor` and set output at
|
||||
// `index` to this tensor.
|
||||
outputs_[index] = TensorValue(new Tensor(tensor));
|
||||
if (track_allocations() && tensor.TotalBytes() > 0) {
|
||||
DCHECK(tracking_state_);
|
||||
mutex_lock l(tracking_state_->stats_mu);
|
||||
const auto it = std::find_if(
|
||||
tracking_state_->temp_tensor_buffer_and_size.begin(),
|
||||
tracking_state_->temp_tensor_buffer_and_size.end(),
|
||||
[&tensor](const std::pair<const void*, int64>& e) {
|
||||
return e.first ==
|
||||
static_cast<const void*>(tensor.tensor_data().data());
|
||||
});
|
||||
if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
|
||||
tracking_state_->temp_memory_allocated -= it->second;
|
||||
tracking_state_->temp_tensor_buffer_and_size.erase(it);
|
||||
}
|
||||
}
|
||||
maybe_track_allocations_for_set_output(*outputs_[index].tensor);
|
||||
}
|
||||
}
|
||||
|
||||
void OpKernelContext::set_output(int index, Tensor&& tensor) {
|
||||
CHECK_GE(index, 0);
|
||||
CHECK_LT(index, outputs_.size());
|
||||
const DataType type = params_->op_kernel->output_type(index);
|
||||
CHECK(!IsRefType(type));
|
||||
CHECK_EQ(outputs_[index].tensor, nullptr);
|
||||
if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
|
||||
// Input can be forwarded to output; set output at `index` to this tensor.
|
||||
outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
|
||||
maybe_track_allocations_for_set_output(*outputs_[index].tensor);
|
||||
}
|
||||
}
|
||||
|
||||
@ -977,28 +988,16 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
|
||||
|
||||
Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
|
||||
Tensor* tensor_for_ref) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued output name '",
|
||||
name,
|
||||
"' when single-valued output was "
|
||||
"expected");
|
||||
}
|
||||
set_output_ref(start, mu, tensor_for_ref);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_output_index(name, &index));
|
||||
set_output_ref(index, mu, tensor_for_ref);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
|
||||
int start, stop;
|
||||
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
|
||||
if (stop != start + 1) {
|
||||
return errors::InvalidArgument("OpKernel used list-valued output name '",
|
||||
name,
|
||||
"' when single-valued output was "
|
||||
"expected");
|
||||
}
|
||||
*tensor = mutable_output(start);
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(get_output_index(name, &index));
|
||||
*tensor = mutable_output(index);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -492,6 +492,7 @@ class OpOutputList {
|
||||
DataType expected_output_dtype(int i) const;
|
||||
Status allocate(int i, const TensorShape& shape, Tensor** output);
|
||||
void set(int i, const Tensor& tensor);
|
||||
void set(int i, Tensor&& tensor);
|
||||
void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
|
||||
int size() const { return stop_ - start_; }
|
||||
Iterator begin() const { return Iterator(this, 0); }
|
||||
@ -1031,6 +1032,9 @@ class OpKernelContext {
|
||||
// REQUIRES: 'tensor' must have the same MemoryType as
|
||||
// output_memory_types[index]. See comment above.
|
||||
Status set_output(StringPiece name, const Tensor& tensor);
|
||||
Status set_output(StringPiece name, Tensor&& tensor);
|
||||
void set_output(int index, const Tensor& tensor);
|
||||
void set_output(int index, Tensor&& tensor);
|
||||
|
||||
// To output a reference. Caller retains ownership of mu and tensor_for_ref,
|
||||
// and they must outlive all uses within the step. See comment above.
|
||||
@ -1198,7 +1202,6 @@ class OpKernelContext {
|
||||
// The following functions all have versions that return Status
|
||||
// to capture error conditions, and are strongly preferred.
|
||||
Tensor* mutable_output(int index);
|
||||
void set_output(int index, const Tensor& tensor);
|
||||
mutex* input_ref_mutex(int index);
|
||||
void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
|
||||
TensorValue release_output(int index);
|
||||
@ -1274,6 +1277,16 @@ class OpKernelContext {
|
||||
Tensor* out_tensor, AllocatorAttributes allocator_attr,
|
||||
const AllocationAttributes& allocation_attr);
|
||||
|
||||
// Helpers for `set_output()`.
|
||||
|
||||
// Returns `true` if the tensor was copied into an allocated output.
|
||||
bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);
|
||||
|
||||
void maybe_track_allocations_for_set_output(const Tensor& tensor);
|
||||
|
||||
Status get_input_index(StringPiece name, int* out_index) const;
|
||||
Status get_output_index(StringPiece name, int* out_index) const;
|
||||
|
||||
// Initialize the allocated_scope_ids_ set the first time this method is
|
||||
// called.
|
||||
void maybe_initialize_scope_id_set();
|
||||
@ -1704,6 +1717,12 @@ inline void OpOutputList::set(int i, const Tensor& tensor) {
  ctx_->set_output(start_ + i, tensor);
}

inline void OpOutputList::set(int i, Tensor&& tensor) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);
  ctx_->set_output(start_ + i, std::move(tensor));
}

inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
  DCHECK_GE(i, 0);
  DCHECK_LT(i, stop_ - start_);

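The new rvalue overloads (OpKernelContext::set_output(int, Tensor&&), set_output(StringPiece, Tensor&&), and OpOutputList::set(int, Tensor&&)) let a kernel hand its tensor to the output slot by moving the underlying buffer reference instead of making a shallow copy of the Tensor object, which is what the function_ops and RemoteCallOp call sites later in this diff switch to. A hedged sketch of a kernel using the move form; PassThroughWithMoveOp is a hypothetical kernel and the snippet assumes the TensorFlow build environment:

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

// Hypothetical kernel illustrating the move-based output path added here;
// this is a sketch, not a kernel that ships with TensorFlow.
class PassThroughWithMoveOp : public OpKernel {
 public:
  explicit PassThroughWithMoveOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    // Build a local tensor (here just a shallow copy of input 0).
    Tensor result = ctx->input(0);
    // The move overload transfers the buffer reference into the output slot;
    // the const-ref overload would instead copy `result`, bumping the buffer
    // refcount a second time.
    ctx->set_output(0, std::move(result));
  }
};

}  // namespace tensorflow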
@ -129,8 +129,43 @@ class RendezvousInterface {
// threads with no clear owner.
class Rendezvous : public RendezvousInterface, public core::RefCounted {
 public:
  using Factory =
      std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>;
  class Factory {
   public:
    // Default to a factory that evaluates to false.
    Factory() : valid_(false) {}

    Factory(std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
                create_fn,
            std::function<Status(const int64)> cleanup_fn)
        : valid_(true),
          create_fn_(std::move(create_fn)),
          cleanup_fn_(std::move(cleanup_fn)) {}

    // If no clean up fn is provided, just put in a dummy.
    // For backwards compatibility.
    explicit Factory(
        std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
            create_fn)
        : valid_(true),
          create_fn_(std::move(create_fn)),
          cleanup_fn_([](const int64 step_id) { return Status::OK(); }) {}

    explicit operator bool() const { return valid_; }

    Status operator()(const int64 step_id, const DeviceMgr* device_mgr,
                      Rendezvous** rendez) const {
      return create_fn_(step_id, device_mgr, rendez);
    }

    Status CleanUp(const int64 step_id) const { return cleanup_fn_(step_id); }

   private:
    bool valid_;
    std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
        create_fn_;
    std::function<Status(const int64)> cleanup_fn_;
  };

  // Constructs a rendezvous key for the tensor of "name" sent from
  // "src_device" to "dst_device". The tensor is generated in the frame
  // and iteration specified by "frame_iter".
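Rendezvous::Factory is no longer a bare std::function; it is a small callable object that carries both a create function and a per-step cleanup function, with explicit operator bool() so existing `if (rendezvous_factory_)` checks keep working. The sketch below shows how such a two-callback factory is built and driven; it uses stub Status, DeviceMgr, and Rendezvous types as stand-ins for the TensorFlow ones, so it illustrates the shape of the API rather than the real class.

#include <cstdint>
#include <functional>
#include <iostream>

// Minimal stand-ins for the TensorFlow types referenced by the Factory.
struct Status { static Status OK() { return {}; } bool ok = true; };
struct DeviceMgr {};
struct Rendezvous {};

// Same shape as the new Rendezvous::Factory: a create callback plus a
// per-step cleanup callback, and operator bool() to test validity.
class Factory {
 public:
  Factory() : valid_(false) {}
  Factory(std::function<Status(int64_t, const DeviceMgr*, Rendezvous**)> create_fn,
          std::function<Status(int64_t)> cleanup_fn)
      : valid_(true),
        create_fn_(std::move(create_fn)),
        cleanup_fn_(std::move(cleanup_fn)) {}

  explicit operator bool() const { return valid_; }
  Status operator()(int64_t step_id, const DeviceMgr* mgr, Rendezvous** r) const {
    return create_fn_(step_id, mgr, r);
  }
  Status CleanUp(int64_t step_id) const { return cleanup_fn_(step_id); }

 private:
  bool valid_;
  std::function<Status(int64_t, const DeviceMgr*, Rendezvous**)> create_fn_;
  std::function<Status(int64_t)> cleanup_fn_;
};

int main() {
  DeviceMgr device_mgr;
  Factory factory(
      [](int64_t step_id, const DeviceMgr*, Rendezvous** r) {
        *r = new Rendezvous;  // create the per-step rendezvous
        std::cout << "created rendezvous for step " << step_id << "\n";
        return Status::OK();
      },
      [](int64_t step_id) {
        std::cout << "cleaned up step " << step_id << "\n";
        return Status::OK();
      });

  if (factory) {  // same validity check the runtime performs
    Rendezvous* r = nullptr;
    factory(/*step_id=*/1, &device_mgr, &r);
    delete r;  // the real Rendezvous is ref-counted; callers Unref() it instead
    factory.CleanUp(/*step_id=*/1);
  }
}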
@ -37,39 +37,49 @@ namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
void RecordPaddingSize(int32 padding_size) {
|
||||
void RecordPaddingSize(int32 padding_size, const string& model_name) {
|
||||
static tensorflow::monitoring::PercentileSamplerCell* cell =
|
||||
tensorflow::monitoring::PercentileSampler<0>::New(
|
||||
{"/tensorflow/serving/batching/padding_size",
|
||||
"Tracks the padding size distribution on batches."},
|
||||
tensorflow::monitoring::PercentileSampler<1>::New(
|
||||
{"/tensorflow/serving/batching/padding_size", "model_name",
|
||||
"Tracks the padding size distribution on batches by model_name (if "
|
||||
"available)."},
|
||||
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
|
||||
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber)
|
||||
->GetCell();
|
||||
->GetCell(model_name);
|
||||
cell->Add(static_cast<double>(padding_size));
|
||||
}
|
||||
|
||||
void RecordInputBatchSize(int32 batch_size) {
|
||||
void RecordInputBatchSize(int32 batch_size, const string& model_name) {
|
||||
static tensorflow::monitoring::PercentileSamplerCell* cell =
|
||||
tensorflow::monitoring::PercentileSampler<0>::New(
|
||||
{"/tensorflow/serving/batching/input_batch_size",
|
||||
"Tracks the batch size distribution on the inputs."},
|
||||
tensorflow::monitoring::PercentileSampler<1>::New(
|
||||
{"/tensorflow/serving/batching/input_batch_size", "model_name",
|
||||
"Tracks the batch size distribution on the inputs by model_name (if "
|
||||
"available)."},
|
||||
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
|
||||
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber)
|
||||
->GetCell();
|
||||
->GetCell(model_name);
|
||||
cell->Add(static_cast<double>(batch_size));
|
||||
}
|
||||
|
||||
void RecordBatchDelayMs(int64 batch_delay_ms) {
|
||||
void RecordBatchDelayMs(int64 batch_delay_ms, const string& model_name) {
|
||||
static monitoring::PercentileSamplerCell* cell =
|
||||
monitoring::PercentileSampler<0>::New(
|
||||
{"/tensorflow/serving/batching/batch_delay_ms",
|
||||
"Tracks the batching delay for inputs."},
|
||||
monitoring::PercentileSampler<1>::New(
|
||||
{"/tensorflow/serving/batching/batch_delay_ms", "model_name",
|
||||
"Tracks the batching delay for inputs by model_name (if "
|
||||
"available)."},
|
||||
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
|
||||
/*max_samples=*/1024, monitoring::UnitOfMeasure::kTime)
|
||||
->GetCell();
|
||||
->GetCell(model_name);
|
||||
cell->Add(static_cast<double>(batch_delay_ms));
|
||||
}
|
||||
|
||||
const string& GetModelName(OpKernelContext* ctx) {
  static string* kModelNameUnset = new string("model_name_unset");
  if (!ctx->session_metadata()) return *kModelNameUnset;
  if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
  return ctx->session_metadata()->name();
}

}  // namespace

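These batching metrics are now keyed by the model name pulled from the kernel context's session metadata, so GetModelName() only returns something useful if the session was built with a populated SessionMetadata. A hedged sketch of setting that up through ConfigProto; it uses the standard TF config proto fields and is not part of this diff:

#include "tensorflow/core/protobuf/config.pb.h"

// Sketch: label a session so GetModelName() returns "my_model" instead of
// the "model_name_unset" fallback above.
tensorflow::ConfigProto MakeLabeledConfig() {
  tensorflow::ConfigProto config;
  auto* metadata = config.mutable_experimental()->mutable_session_metadata();
  metadata->set_name("my_model");
  metadata->set_version(1);
  return config;
}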
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
@ -303,7 +313,7 @@ class BatchResource : public ResourceBase {
|
||||
}
|
||||
batch_components->inputs.push_back(tensor);
|
||||
}
|
||||
RecordInputBatchSize(tensors[0].shape().dim_size(0));
|
||||
RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context));
|
||||
OpInputList captured_tensors;
|
||||
const auto captured_status =
|
||||
context->input_list("captured_tensors", &captured_tensors);
|
||||
@ -386,7 +396,7 @@ class BatchResource : public ResourceBase {
|
||||
|
||||
const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
|
||||
const int padding_amount = padded_batch_size - batch.size();
|
||||
RecordPaddingSize(padding_amount);
|
||||
RecordPaddingSize(padding_amount, GetModelName(context));
|
||||
|
||||
// All tasks should have the same number of input edges.
|
||||
const int num_inputs = batch.task(0).inputs.size();
|
||||
@ -570,8 +580,10 @@ class BatchResource : public ResourceBase {
|
||||
args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
|
||||
|
||||
uint64 current_time = EnvTime::NowNanos();
|
||||
const string& model_name = GetModelName(last_task_context);
|
||||
for (int i = 0; i < batch->num_tasks(); ++i) {
|
||||
RecordBatchDelayMs((current_time - batch->task(i).start_time) * 1e-6);
|
||||
RecordBatchDelayMs((current_time - batch->task(i).start_time) * 1e-6,
|
||||
model_name);
|
||||
}
|
||||
// Releases the cleanup method here, because the callback of the function
|
||||
// library runtime will handle it now.
|
||||
|
@ -405,10 +405,11 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
|
||||
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
|
||||
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
|
||||
/*session_metadata=*/nullptr,
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
});
|
||||
Rendezvous::Factory{
|
||||
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
|
||||
*r = new IntraProcessRendezvous(device_mgr);
|
||||
return Status::OK();
|
||||
}});
|
||||
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
|
||||
if (thread_pool_ == nullptr) {
|
||||
runner_ = [](const std::function<void()>& fn) { fn(); };
|
||||
@ -548,8 +549,7 @@ Status DatasetOpsTestBase::AddDatasetInput(
|
||||
inputs->size(), " vs. ", input_types.size());
|
||||
}
|
||||
bool is_ref = IsRefType(input_types[inputs->size()]);
|
||||
std::unique_ptr<Tensor> input =
|
||||
absl::make_unique<Tensor>(allocator_, dtype, shape);
|
||||
auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);
|
||||
|
||||
if (is_ref) {
|
||||
DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
|
||||
|
@ -50,7 +50,7 @@ void ArgOp::Compute(OpKernelContext* ctx) {
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  ctx->set_output(0, val);
  ctx->set_output(0, std::move(val));
}

RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||
@ -279,7 +279,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
|
||||
" tensor(s), but get ", rets->size(), " tensor(s) instead."));
|
||||
} else {
|
||||
for (size_t i = 0; i < rets->size(); ++i) {
|
||||
ctx->set_output(i, (*rets)[i]);
|
||||
ctx->set_output(i, std::move((*rets)[i]));
|
||||
}
|
||||
}
|
||||
delete rets;
|
||||
@ -413,7 +413,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
ctx->SetStatus(status);
|
||||
} else {
|
||||
for (size_t i = 0; i < rets->size(); ++i) {
|
||||
ctx->set_output(i, (*rets)[i]);
|
||||
ctx->set_output(i, std::move((*rets)[i]));
|
||||
}
|
||||
}
|
||||
delete rets;
|
||||
|
@ -100,6 +100,12 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
|
||||
errors::InvalidArgument(
|
||||
"Number of targets * number of classes must be less than INT_MAX"));
|
||||
|
||||
if (num_targets == 0 || num_classes == 0) {
|
||||
// Result is empty, so shortcut the rest of the function to avoid
|
||||
// launching kernels with empty input.
|
||||
return;
|
||||
}
|
||||
|
||||
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
|
||||
Tensor predictions_mask;
|
||||
OP_REQUIRES_OK(
|
||||
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/framework/node_properties.h"
|
||||
#ifdef GOOGLE_CUDA
|
||||
#define EIGEN_USE_GPU
|
||||
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
|
||||
@ -137,11 +138,16 @@ Status OpsTestBase::InitOp() {
|
||||
}
|
||||
|
||||
Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) {
|
||||
Status status;
|
||||
kernel_ = CreateOpKernel(device_type_, device_, allocator(), node_def_,
|
||||
graph_def_version, &status);
|
||||
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
|
||||
return status;
|
||||
std::shared_ptr<const NodeProperties> props;
|
||||
TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
|
||||
node_def_, OpRegistry::Global(), &props));
|
||||
OpKernel* kernel;
|
||||
TF_RETURN_IF_ERROR(CreateOpKernel(
|
||||
device_type_, device_, allocator(), /*flib=*/nullptr,
|
||||
device_->resource_manager(), props, graph_def_version, &kernel));
|
||||
kernel_.reset(kernel);
|
||||
input_types_ = kernel_->input_types();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status OpsTestBase::RunOpKernel() {
|
||||
|
@ -224,10 +224,8 @@ REGISTER_CPU(complex128)
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
|
||||
|
||||
REGISTER_GPU(GPU, float)
|
||||
REGISTER_GPU(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(GPU, complex64)
|
||||
REGISTER_GPU(GPU, complex128)
|
||||
#endif
|
||||
|
||||
namespace functor {
|
||||
|
||||
|
@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
|
||||
|
||||
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
|
||||
|
||||
if (b_outer_dim == 1) {
|
||||
bool use_matrix_vector_multiply = (b_outer_dim == 1);
|
||||
#if TENSORFLOW_USE_ROCM
|
||||
// ROCm hipsparse does not implement csrmv with transposed input a
|
||||
use_matrix_vector_multiply =
|
||||
use_matrix_vector_multiply && !this->transpose_a_;
|
||||
#endif
|
||||
if (use_matrix_vector_multiply) {
|
||||
// Call matrix-vector multiply if b is a vector.
|
||||
TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
|
||||
2);
|
||||
|
@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel {
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
|
@ -120,10 +120,8 @@ REGISTER(CPU, complex128)
|
||||
|
||||
REGISTER(GPU, float)
|
||||
REGISTER(GPU, double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER(GPU, complex64)
|
||||
REGISTER(GPU, complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER
|
||||
|
||||
@ -141,10 +139,8 @@ namespace functor {
|
||||
DECLARE_GPU_SPEC(int32);
|
||||
DECLARE_GPU_SPEC(float);
|
||||
DECLARE_GPU_SPEC(double);
|
||||
#if GOOGLE_CUDA
|
||||
DECLARE_GPU_SPEC(complex64);
|
||||
DECLARE_GPU_SPEC(complex128);
|
||||
#endif
|
||||
|
||||
#undef DECLARE_GPU_SPEC
|
||||
} // namespace functor
|
||||
|
@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
|
||||
|
||||
REGISTER_GPU(float)
|
||||
REGISTER_GPU(double)
|
||||
#if GOOGLE_CUDA
|
||||
REGISTER_GPU(complex64)
|
||||
REGISTER_GPU(complex128)
|
||||
#endif
|
||||
|
||||
#undef REGISTER_GPU
|
||||
|
||||
|
@ -50,15 +50,11 @@ class LegacyVar : public ResourceBase {
|
||||
VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
|
||||
dtype_ = RemoveRefType(context->output_type(0));
|
||||
OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
|
||||
true /* use name() */));
|
||||
}
|
||||
|
||||
void VariableOp::Compute(OpKernelContext* ctx) {
|
||||
mutex_lock l(init_mu_);
|
||||
if (!initialized_) {
|
||||
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
|
||||
true /* use name() */));
|
||||
initialized_ = true;
|
||||
}
|
||||
auto creator = [this](LegacyVar** var) {
|
||||
*var = new LegacyVar(dtype_);
|
||||
(*var)->tensor()->set_shape(shape_);
|
||||
|
@ -36,10 +36,7 @@ class VariableOp : public OpKernel {
|
||||
private:
|
||||
DataType dtype_;
|
||||
TensorShape shape_;
|
||||
|
||||
mutex init_mu_;
|
||||
ContainerInfo cinfo_ TF_GUARDED_BY(init_mu_);
|
||||
bool initialized_ TF_GUARDED_BY(init_mu_){false};
|
||||
ContainerInfo cinfo_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(VariableOp);
|
||||
};
|
||||
|
@ -36,8 +36,22 @@ struct ProfilerOptions {
|
||||
// DeviceType::kTpu: only CPU/TPU will be profiled.
|
||||
DeviceType device_type = DeviceType::kUnspecified;
|
||||
|
||||
// Inexpensive ops are not traced by default.
|
||||
int host_tracer_level = 2;
|
||||
// Levels of host tracing:
|
||||
// - Level 0 is used to disable host traces.
|
||||
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
|
||||
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
|
||||
// level program execution details (expensive TF ops, XLA ops, etc).
|
||||
// This is the default.
|
||||
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
|
||||
// (low-level) program execution details (cheap TF ops, etc).
|
||||
uint32 host_tracer_level = 2;
|
||||
|
||||
// Levels of device tracing:
|
||||
// - Level 0 is used to disable device traces.
|
||||
// - Level 1 is used to enable device traces.
|
||||
// - More levels might be defined for specific device for controlling the
|
||||
// verbosity of the trace.
|
||||
uint32 device_tracer_level = 1;
|
||||
|
||||
  // Whether to enable the Python function call tracer.
bool enable_python_tracer = false;
|
||||
|
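With host_tracer_level promoted to a documented uint32, code that constructs ProfilerOptions directly can pick a verbosity level before creating a session, much like ProfilerServiceImpl::GetOptions() does further down in this diff. A hedged sketch; it assumes the TF profiler headers and the field and function names shown here:

#include <memory>

#include "tensorflow/core/profiler/lib/profiler_session.h"

// Sketch only: configure tracing verbosity and create a profiling session.
std::unique_ptr<tensorflow::ProfilerSession> StartVerboseProfiling() {
  tensorflow::profiler::ProfilerOptions options;
  options.host_tracer_level = 3;    // also trace cheap host ops
  options.device_tracer_level = 1;  // default device tracing
  options.enable_python_tracer = false;
  return tensorflow::ProfilerSession::Create(options);
}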
@ -14,11 +14,37 @@ service ProfilerService {
|
||||
}
|
||||
|
||||
message ProfileOptions {
|
||||
// Some default value of option are not proto3 default value. Use this version
|
||||
// to determine if we should use default option value instead of proto3
|
||||
// default value.
|
||||
uint32 version = 5;
|
||||
|
||||
  // We don't collect the dataset ops by default for better trace-viewer
  // scalability. The caller can manually set this field to include the ops.
bool include_dataset_ops = 1;
|
||||
|
||||
// next-field: 2
|
||||
// Levels of host tracing: (version >= 1)
|
||||
// - Level 0 is used to disable host traces.
|
||||
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
|
||||
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
|
||||
// level program execution details (expensive TF ops, XLA ops, etc).
|
||||
// This is the default.
|
||||
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
|
||||
// (low-level) program execution details (cheap TF ops, etc).
|
||||
uint32 host_tracer_level = 2;
|
||||
|
||||
// Levels of device tracing: (version >= 1)
|
||||
// - Level 0 is used to disable device traces.
|
||||
// - Level 1 is used to enable device traces.
|
||||
// - More levels might be defined for specific device for controlling the
|
||||
// verbosity of the trace.
|
||||
uint32 device_tracer_level = 3;
|
||||
|
||||
  // Whether to enable tracing of Python function calls. Runtime overhead
  // ensues if enabled. Default: off. (version >= 1)
uint32 python_tracer_level = 4;
|
||||
|
||||
// next-field: 6
|
||||
}
|
||||
|
||||
message ToolRequestOptions {
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/env_time.h"
|
||||
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
|
||||
#include "tensorflow/core/profiler/internal/profiler_interface.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/util/ptr_util.h"
|
||||
@ -51,7 +52,8 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
||||
::grpc::Status Profile(::grpc::ServerContext* ctx, const ProfileRequest* req,
|
||||
ProfileResponse* response) override {
|
||||
VLOG(1) << "Received a profile request: " << req->DebugString();
|
||||
std::unique_ptr<ProfilerSession> profiler = ProfilerSession::Create();
|
||||
std::unique_ptr<ProfilerSession> profiler =
|
||||
ProfilerSession::Create(GetOptions(req->opts()));
|
||||
Status status = profiler->Status();
|
||||
if (!status.ok()) {
|
||||
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
|
||||
@ -74,6 +76,19 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
|
||||
|
||||
return ::grpc::Status::OK;
|
||||
}
|
||||
|
||||
 private:
  profiler::ProfilerOptions GetOptions(const tensorflow::ProfileOptions& opts) {
    profiler::ProfilerOptions options;
    if (opts.version()) {
      options.host_tracer_level = opts.host_tracer_level();
      options.device_tracer_level = opts.device_tracer_level();
      options.enable_python_tracer = opts.python_tracer_level() > 0;
    } else {
      // Otherwise keep the default option values.
    }
    return options;
  }
};

}  // namespace

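On the RPC side, a client opts into these knobs by setting `version` to a non-zero value so that GetOptions() above honors the request fields instead of falling back to defaults. A hedged sketch of building such a request; the proto setters follow the usual generated-code naming and the include path is an assumption, not something shown in this diff:

#include "tensorflow/core/profiler/profiler_service.pb.h"

// Sketch: request verbose host tracing and keep the Python tracer off.
tensorflow::ProfileOptions MakeProfileOptions() {
  tensorflow::ProfileOptions opts;
  opts.set_version(1);              // non-zero: honor the fields below
  opts.set_host_tracer_level(3);    // verbose host traces
  opts.set_device_tracer_level(1);  // default device tracing
  opts.set_python_tracer_level(0);  // Python call tracing disabled
  return opts;
}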
@ -122,6 +122,8 @@ const StatTypeMap& GetStatTypeMap() {
|
||||
{"fragmentation", kFragmentation},
|
||||
{"peak_bytes_in_use", kPeakBytesInUse},
|
||||
{"requested_bytes", kRequestedBytes},
|
||||
{"allocation_bytes", kAllocationBytes},
|
||||
{"addr", kAddress},
|
||||
{"shape", kTensorShapes},
|
||||
// Device trace arguments.
|
||||
{"device_id", kDeviceId},
|
||||
|
@ -113,6 +113,8 @@ enum StatType {
|
||||
kFragmentation,
|
||||
kPeakBytesInUse,
|
||||
kRequestedBytes,
|
||||
kAllocationBytes,
|
||||
kAddress,
|
||||
kTensorShapes,
|
||||
// Device trace arguments.
|
||||
kDeviceId,
|
||||
|
@ -12022,7 +12022,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2
|
||||
//
|
||||
// value: The cropped area of the image must have an aspect ratio =
|
||||
// width / height within this range.
|
||||
// If not specified, defaults to {f:0.75 f:1.33}
|
||||
// If not specified, defaults to {f:0.75 f:1.33}
|
||||
func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["aspect_ratio_range"] = value
|
||||
@ -12033,7 +12033,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
|
||||
//
|
||||
// value: The cropped area of the image must contain a fraction of the
|
||||
// supplied image within this range.
|
||||
// If not specified, defaults to {f:0.05 f:1}
|
||||
// If not specified, defaults to {f:0.05 f:1}
|
||||
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["area_range"] = value
|
||||
@ -12251,7 +12251,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo
|
||||
//
|
||||
// value: The cropped area of the image must have an aspect ratio =
|
||||
// width / height within this range.
|
||||
// If not specified, defaults to {f:0.75 f:1.33}
|
||||
// If not specified, defaults to {f:0.75 f:1.33}
|
||||
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["aspect_ratio_range"] = value
|
||||
@ -12262,7 +12262,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
|
||||
//
|
||||
// value: The cropped area of the image must contain a fraction of the
|
||||
// supplied image within this range.
|
||||
// If not specified, defaults to {f:0.05 f:1}
|
||||
// If not specified, defaults to {f:0.05 f:1}
|
||||
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["area_range"] = value
|
||||
@ -19038,7 +19038,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr {
|
||||
// ImageSummaryBadColor sets the optional bad_color attribute to value.
|
||||
//
|
||||
// value: Color to use for pixels with non-finite values.
|
||||
// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
|
||||
// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
|
||||
func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["bad_color"] = value
|
||||
@ -20109,7 +20109,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr {
|
||||
// filter element on that dimension. The dimension order is determined by the
|
||||
// value of `data_format`, see above for details. Dilations in the batch and
|
||||
// depth dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -21281,7 +21281,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr {
|
||||
// element on that dimension. The dimension order is determined by the value of
|
||||
// `data_format`, see above for details. Dilations in the batch and depth
|
||||
// dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -21989,7 +21989,7 @@ func Conv2DDataFormat(value string) Conv2DAttr {
|
||||
// filter element on that dimension. The dimension order is determined by the
|
||||
// value of `data_format`, see above for details. Dilations in the batch and
|
||||
// depth dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func Conv2DDilations(value []int64) Conv2DAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22185,7 +22185,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy
|
||||
// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
|
||||
//
|
||||
// value: List of dilation values.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22254,7 +22254,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized
|
||||
// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
|
||||
//
|
||||
// value: List of dilation values.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi
|
||||
// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
|
||||
//
|
||||
// value: List of dilation values.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22428,7 +22428,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D
|
||||
// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
|
||||
//
|
||||
// value: List of dilation values.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22602,7 +22602,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann
|
||||
// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
|
||||
//
|
||||
// value: list of dilation values.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -22979,7 +22979,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr {
|
||||
// filter element on that dimension. The dimension order is determined by the
|
||||
// value of `data_format`, see above for details. Dilations in the batch and
|
||||
// depth dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -25322,7 +25322,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi
|
||||
type Conv3DBackpropFilterAttr func(optionalAttr)
|
||||
|
||||
// Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -25385,7 +25385,7 @@ func Conv3DDataFormat(value string) Conv3DAttr {
|
||||
// filter element on that dimension. The dimension order is determined by the
|
||||
// value of `data_format`, see above for details. Dilations in the batch and
|
||||
// depth dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
func Conv3DDilations(value []int64) Conv3DAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -25636,7 +25636,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN
|
||||
// element on that dimension. The dimension order is determined by the value of
|
||||
// `data_format`, see above for details. Dilations in the batch and depth
|
||||
// dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -26120,7 +26120,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr {
|
||||
// filter element on that dimension. The dimension order is determined by the
|
||||
// value of `data_format`, see above for details. Dilations in the batch and
|
||||
// depth dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -40326,7 +40326,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d
|
||||
// element on that dimension. The dimension order is determined by the value of
|
||||
// `data_format`, see above for details. Dilations in the batch and depth
|
||||
// dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -45852,7 +45852,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr {
|
||||
// element on that dimension. The dimension order is determined by the value of
|
||||
// `data_format`, see above for details. Dilations in the batch and depth
|
||||
// dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -46704,7 +46704,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula
|
||||
type Conv3DBackpropInputAttr func(optionalAttr)
|
||||
|
||||
// Conv3DBackpropInputDilations sets the optional dilations attribute to value.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
|
||||
func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
@ -46775,7 +46775,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr {
|
||||
// element on that dimension. The dimension order is determined by the value of
|
||||
// `data_format`, see above for details. Dilations in the batch and depth
|
||||
// dimensions must be 1.
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
// If not specified, defaults to {i:1 i:1 i:1 i:1}
|
||||
func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dilations"] = value
|
||||
|
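Note: the generated Go setters above only copy the `dilations` list into the op's attribute map; what the attribute means is the usual dilated-convolution geometry described in the doc comments (spacing between kernel taps, with the batch and depth entries fixed at 1). As a generic illustration only, not code from this diff, here is a minimal standalone C++ sketch of that arithmetic; the function and variable names are made up.

#include <cassert>
#include <cstdio>

// Effective extent of a kernel of size k whose taps are spaced by dilation d.
int EffectiveKernelSize(int kernel_size, int dilation) {
  return (kernel_size - 1) * dilation + 1;
}

// Output length of a VALID convolution along one spatial dimension.
int ValidOutputSize(int input_size, int kernel_size, int dilation, int stride) {
  const int effective = EffectiveKernelSize(kernel_size, dilation);
  assert(input_size >= effective);
  return (input_size - effective) / stride + 1;
}

int main() {
  // A 3-tap kernel with dilation 2 covers 5 input elements.
  std::printf("effective=%d\n", EffectiveKernelSize(3, 2));  // 5
  std::printf("output=%d\n", ValidOutputSize(32, 3, 2, 1));  // 28
  return 0;
}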
@ -274,9 +274,9 @@ class OpNode {
|
||||
return tensorflow::errors::Internal(
|
||||
"Cannot read from invalid tensor index ", input_index);
|
||||
}
|
||||
tensorflow::TensorHandle* handle;
|
||||
TF_RETURN_IF_ERROR(tensorflow::TensorHandle::CreateLocalHandle(
|
||||
buffer_map->GetTensor(input_index), &handle));
|
||||
tensorflow::TensorHandle* handle =
|
||||
tensorflow::TensorHandle::CreateLocalHandle(
|
||||
buffer_map->GetTensor(input_index));
|
||||
op_->MutableInputs()->push_back(handle);
|
||||
} else {
|
||||
// If this is a forwardable tensor, we will remove it from the previous
|
||||
|
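The hunk above replaces the Status-returning TensorHandle::CreateLocalHandle overload (result through an out-parameter) with one that returns the handle directly. Below is a standalone sketch of that API-shape change with placeholder types, not TensorFlow's; the point is only that a factory which cannot fail no longer needs a status plus out-parameter.

#include <cstdio>

struct Tensor { int id = 0; };           // stand-in for tensorflow::Tensor
struct TensorHandle { Tensor tensor; };  // stand-in for tensorflow::TensorHandle

// Old shape: success reported through a status-like value, result handed back
// through an out-parameter.
bool CreateLocalHandleOld(const Tensor& t, TensorHandle** out) {
  *out = new TensorHandle{t};
  return true;
}

// New shape: construction cannot fail, so the handle is returned directly.
TensorHandle* CreateLocalHandleNew(const Tensor& t) { return new TensorHandle{t}; }

int main() {
  TensorHandle* a = nullptr;
  if (CreateLocalHandleOld(Tensor{1}, &a)) std::printf("old: %d\n", a->tensor.id);
  TensorHandle* b = CreateLocalHandleNew(Tensor{2});
  std::printf("new: %d\n", b->tensor.id);
  delete a;
  delete b;
  return 0;
}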
@ -237,6 +237,7 @@ cc_library(
|
||||
"//tensorflow/lite/delegates/gpu/common:model_transformer",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/gl:api2",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <thread> // NOLINT(build/c++11)
|
||||
#include <vector>
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/lite/builtin_ops.h"
|
||||
#include "tensorflow/lite/delegates/gpu/api.h"
|
||||
@ -70,6 +71,28 @@ class Delegate {
|
||||
options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default();
|
||||
}
|
||||
|
||||
TfLiteDelegate* tflite_delegate() { return &delegate_; }
|
||||
const TfLiteGpuDelegateOptionsV2& options() const { return options_; }
|
||||
|
||||
private:
|
||||
TfLiteDelegate delegate_ = {
|
||||
.data_ = reinterpret_cast<void*>(this),
|
||||
.Prepare = DelegatePrepare,
|
||||
.CopyFromBufferHandle = nullptr,
|
||||
.CopyToBufferHandle = nullptr,
|
||||
.FreeBufferHandle = nullptr,
|
||||
.flags = kTfLiteDelegateFlagsNone,
|
||||
};
|
||||
|
||||
TfLiteGpuDelegateOptionsV2 options_;
|
||||
};
|
||||
|
||||
// Represent the execution of a subset of nodes on GPU.
|
||||
class DelegateKernel {
|
||||
public:
|
||||
explicit DelegateKernel(const TfLiteGpuDelegateOptionsV2& options)
|
||||
: options_(options) {}
|
||||
|
||||
absl::Status Prepare(TfLiteContext* context,
|
||||
const TfLiteDelegateParams* delegate_params) {
|
||||
thread_id_prepare_ = std::this_thread::get_id();
|
||||
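The delegate_ member above is now written with designated initializers (.data_ = ...) instead of the positional initialization with trailing "// .field" comments that the removed block further down still shows. A minimal standalone sketch of the idiom, using a made-up struct rather than TfLiteDelegate:

#include <cstdio>

struct Callbacks {
  void* data = nullptr;
  int (*prepare)(void*) = nullptr;
  int flags = 0;
};

static int Prepare(void*) { return 0; }

int main() {
  // Designated initializers (C++20, and a long-standing compiler extension)
  // name each field, so reordering or skipping members cannot silently
  // mis-assign them the way positional initialization can.
  Callbacks cb = {
      .data = nullptr,
      .prepare = Prepare,
      .flags = 0,
  };
  std::printf("flags=%d rc=%d\n", cb.flags, cb.prepare(cb.data));
  return 0;
}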
@ -133,20 +156,6 @@ class Delegate {
|
||||
return builder->Build(&runner_);
|
||||
}
|
||||
|
||||
absl::Status SetInputsAndOutputs(TfLiteContext* context) {
|
||||
int i = 0;
|
||||
for (auto index : input_indices_) {
|
||||
RETURN_IF_ERROR(
|
||||
runner_->SetInputObject(i++, GetTensorObject(index, context)));
|
||||
}
|
||||
i = 0;
|
||||
for (auto index : output_indices_) {
|
||||
RETURN_IF_ERROR(
|
||||
runner_->SetOutputObject(i++, GetTensorObject(index, context)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status Invoke(TfLiteContext* context) {
|
||||
if (thread_id_prepare_ != std::this_thread::get_id()) {
|
||||
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
|
||||
@ -162,6 +171,19 @@ class Delegate {
|
||||
return runner_->Run();
|
||||
}
|
||||
|
||||
private:
|
||||
absl::Status SetInputsAndOutputs(TfLiteContext* context) {
|
||||
for (int i = 0; i < input_indices_.size(); ++i) {
|
||||
RETURN_IF_ERROR(runner_->SetInputObject(
|
||||
i, GetTensorObject(input_indices_[i], context)));
|
||||
}
|
||||
for (int i = 0; i < output_indices_.size(); ++i) {
|
||||
RETURN_IF_ERROR(runner_->SetOutputObject(
|
||||
i, GetTensorObject(output_indices_[i], context)));
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
ObjectDef GetObjectDef(int index) const {
|
||||
ObjectDef default_object_def;
|
||||
default_object_def.data_type = DataType::FLOAT32;
|
||||
@ -176,9 +198,6 @@ class Delegate {
|
||||
return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes));
|
||||
}
|
||||
|
||||
TfLiteDelegate* tflite_delegate() { return &delegate_; }
|
||||
|
||||
private:
|
||||
absl::Status InitializeOpenClApi(GraphFloat32* graph,
|
||||
std::unique_ptr<InferenceBuilder>* builder,
|
||||
bool* graph_is_destroyed) {
|
||||
@ -230,28 +249,20 @@ class Delegate {
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
TfLiteDelegate delegate_ = {
|
||||
reinterpret_cast<void*>(this), // .data_
|
||||
DelegatePrepare, // .Prepare
|
||||
nullptr, // .CopyFromBufferHandle
|
||||
nullptr, // .CopyToBufferHandle
|
||||
nullptr, // .FreeBufferHandle
|
||||
kTfLiteDelegateFlagsNone, // .flags
|
||||
};
|
||||
|
||||
TfLiteGpuDelegateOptionsV2 options_;
|
||||
// Shared across all DelegateKernel instances, passed by the Delegate
|
||||
// instance.
|
||||
const TfLiteGpuDelegateOptionsV2& options_;
|
||||
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
|
||||
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
|
||||
std::unique_ptr<InferenceRunner> runner_;
|
||||
std::vector<int64_t> input_indices_;
|
||||
std::vector<int64_t> output_indices_;
|
||||
|
||||
std::thread::id thread_id_prepare_;  // thread id used for Prepare()

|
||||
bool enforce_same_thread_ = false; // flag to enforce same thread for Invoke
|
||||
};
|
||||
|
||||
inline Delegate* GetDelegate(TfLiteNode* node) {
|
||||
return reinterpret_cast<Delegate*>(node->user_data);
|
||||
inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) {
|
||||
return reinterpret_cast<DelegateKernel*>(node->user_data);
|
||||
}
|
||||
|
||||
inline Delegate* GetDelegate(TfLiteDelegate* delegate) {
|
||||
@ -267,16 +278,20 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
auto* gpu_delegate = GetDelegate(params->delegate);
|
||||
// Everything below should happen in prepare function call, but TFLite
|
||||
// for whatever reason forbids that.
|
||||
const auto status = gpu_delegate->Prepare(context, params);
|
||||
auto gpu_delegate_kernel =
|
||||
absl::make_unique<DelegateKernel>(gpu_delegate->options());
|
||||
const auto status = gpu_delegate_kernel->Prepare(context, params);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, "TfLiteGpuDelegate Init: %s",
|
||||
std::string(status.message()).c_str());
|
||||
return nullptr;
|
||||
}
|
||||
return gpu_delegate;
|
||||
return gpu_delegate_kernel.release();
|
||||
},
|
||||
// .free
|
||||
[](TfLiteContext*, void* buffer) -> void {},
|
||||
[](TfLiteContext*, void* buffer) -> void {
|
||||
delete reinterpret_cast<DelegateKernel*>(buffer);
|
||||
},
|
||||
// .prepare
|
||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
||||
if (!node->user_data) {
|
||||
@ -292,7 +307,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
|
||||
},
|
||||
// .invoke
|
||||
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
|
||||
const auto status = GetDelegate(node)->Invoke(context);
|
||||
const auto status = GetDelegateKernel(node)->Invoke(context);
|
||||
if (!status.ok()) {
|
||||
context->ReportError(context, "TfLiteGpuDelegate Invoke: %s",
|
||||
std::string(status.message()).c_str());
|
||||
|
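Taken together, these hunks split the monolithic Delegate into a thin Delegate that only carries the options and a per-partition DelegateKernel that owns the inference state, created in the registration's init callback and destroyed in free. Below is a standalone sketch of that ownership pattern with simplified placeholder types (bool instead of absl::Status, no TfLite types).

#include <cstdio>
#include <memory>

struct Options { bool allow_quant = false; };  // stand-in for the delegate options

// Long-lived object registered with the interpreter; it only carries options.
class Delegate {
 public:
  explicit Delegate(const Options& options) : options_(options) {}
  const Options& options() const { return options_; }
 private:
  Options options_;
};

// One instance per delegated partition; owns builders, runners and bindings.
class DelegateKernel {
 public:
  explicit DelegateKernel(const Options& options) : options_(options) {}
  bool Prepare() { prepared_ = true; return true; }  // would build the runner
  bool Invoke() const { return prepared_; }          // would run inference
 private:
  const Options& options_;  // shared; the Delegate outlives every kernel
  bool prepared_ = false;
};

// What the .init callback does: create a kernel and hand it to the runtime
// through the node's user_data pointer.
void* InitKernel(Delegate* delegate) {
  auto kernel = std::make_unique<DelegateKernel>(delegate->options());
  if (!kernel->Prepare()) return nullptr;
  return kernel.release();
}

// What the .free callback does: reclaim that ownership.
void FreeKernel(void* user_data) { delete static_cast<DelegateKernel*>(user_data); }

int main() {
  Delegate delegate(Options{true});
  void* user_data = InitKernel(&delegate);
  auto* kernel = static_cast<DelegateKernel*>(user_data);
  std::printf("invoke ok: %d\n", kernel && kernel->Invoke());
  FreeKernel(user_data);
  return 0;
}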
@ -119,7 +119,7 @@ class Softmax : public NodeShader {
|
||||
if (z < $depth$) {
|
||||
highp vec4 src = $input_data_0[0, 0, z]$;
|
||||
highp vec4 temp = exp(src) * sum;
|
||||
$output_data_0[0, 0, z]$ = temp;
|
||||
$output_data_0[0, 0, z] = temp$;
|
||||
offset += 32;
|
||||
}
|
||||
s++;
|
||||
|
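The shader change above only moves the assignment inside the templated $...$ accessor; the math is unchanged. The kernel appears to scale exp(src) by a precomputed normalization factor, i.e. a softmax. For reference, a plain C++ sketch of that arithmetic on a small vector, independent of the shader templating:

#include <cmath>
#include <cstdio>
#include <vector>

// softmax_i = exp(x_i) / sum_j exp(x_j); the GPU kernel folds the division
// into a precomputed reciprocal. (A production kernel also subtracts the max
// for numerical stability, omitted here for brevity.)
std::vector<float> Softmax(const std::vector<float>& x) {
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v);
  const float inv_sum = 1.0f / sum;
  std::vector<float> out;
  out.reserve(x.size());
  for (float v : x) out.push_back(std::exp(v) * inv_sum);
  return out;
}

int main() {
  for (float p : Softmax({1.0f, 2.0f, 3.0f})) std::printf("%.4f ", p);
  std::printf("\n");  // ~0.0900 0.2447 0.6652
  return 0;
}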
@ -173,29 +173,6 @@ objc_library(
|
||||
],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "environment_test_lib",
|
||||
testonly = 1,
|
||||
srcs = ["environment_test.mm"],
|
||||
sdk_frameworks = ["XCTest"],
|
||||
deps = [
|
||||
":environment",
|
||||
"//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
ios_unit_test(
|
||||
name = "environment_test",
|
||||
testonly = 1,
|
||||
minimum_os_version = "10.0",
|
||||
runner = tflite_ios_lab_runner("IOS_LATEST"),
|
||||
tags = tf_gpu_tests_tags() + [
|
||||
"notap",
|
||||
"tflite_not_portable_android",
|
||||
],
|
||||
deps = [":environment_test_lib"],
|
||||
)
|
||||
|
||||
objc_library(
|
||||
name = "inference_context",
|
||||
srcs = ["inference_context.mm"],
|
||||
@ -273,7 +250,6 @@ objc_library(
|
||||
srcs = [
|
||||
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:environment_test.mm",
|
||||
"//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm",
|
||||
],
|
||||
hdrs = [
|
||||
|
@ -88,12 +88,14 @@ std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> SelectConvolutionTransposed(
|
||||
int id, ValueId input_id, ValueId output_id,
|
||||
const ConvolutionTransposedAttributes& attr,
|
||||
const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info,
|
||||
const metal::RuntimeOptions& options) {
|
||||
if (CheckConvolutionTransposed4x4Support(attr)) {
|
||||
return ConvolutionTransposed4x4(id, input_id, output_id, attr, options);
|
||||
return ConvolutionTransposed4x4(id, input_id, output_id, attr, device_info,
|
||||
options);
|
||||
} else {
|
||||
return ConvolutionTransposed(id, input_id, output_id, attr, options);
|
||||
return ConvolutionTransposed(id, input_id, output_id, attr, device_info,
|
||||
options);
|
||||
}
|
||||
}
|
||||
|
||||
@ -165,6 +167,7 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
|
||||
absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
const std::vector<ValueId>& inputs,
|
||||
const std::vector<ValueId>& outputs,
|
||||
const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options,
|
||||
int* last_node_id, int* last_value_id,
|
||||
std::vector<ComputeTaskDescriptorPtr>* tasks) {
|
||||
@ -219,8 +222,9 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
|
||||
BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
|
||||
(*last_node_id) += 1;
|
||||
auto t1 = ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
|
||||
conv_shape, attr, options);
|
||||
auto t1 =
|
||||
ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
|
||||
conv_shape, attr, device_info, options);
|
||||
tasks->insert(tasks->end(), t1.begin(), t1.end());
|
||||
|
||||
Winograd36To4x4Attributes wino_down_attr;
|
||||
@ -233,7 +237,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
(*last_value_id) += 2;
|
||||
} else {
|
||||
*tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
|
||||
attr, options);
|
||||
attr, device_info, options);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -242,7 +246,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
node_id, inputs[0], outputs[0],
|
||||
absl::any_cast<ConvolutionTransposedAttributes>(
|
||||
node->operation.attributes),
|
||||
options);
|
||||
device_info, options);
|
||||
break;
|
||||
case OperationType::DEPTHWISE_CONVOLUTION:
|
||||
*tasks =
|
||||
@ -255,7 +259,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
*tasks = FullyConnected(
|
||||
node_id, inputs[0], outputs[0],
|
||||
absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
|
||||
options);
|
||||
device_info, options);
|
||||
break;
|
||||
case OperationType::MAX_UNPOOLING_2D:
|
||||
*tasks = MaxUnpooling(
|
||||
@ -388,7 +392,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
|
||||
|
||||
} // namespace
|
||||
|
||||
absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
|
||||
absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options,
|
||||
CompiledModel* compiled_model) {
|
||||
int last_node_id = 0;
|
||||
for (const auto& node : graph.nodes()) {
|
||||
@ -412,7 +417,7 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
|
||||
RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
|
||||
if (!custom_status.ok()) {
|
||||
auto primary_status =
|
||||
RegisterPrimaryOps(graph, node, inputs, outputs, options,
|
||||
RegisterPrimaryOps(graph, node, inputs, outputs, device_info, options,
|
||||
&last_node_id, &last_value_id, &tasks);
|
||||
if (!primary_status.ok()) {
|
||||
return absl::UnimplementedError(
|
||||
|
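The Winograd path above builds an intermediate tensor of shape {b, 36, tiles_x * tiles_y, c} before calling ConvolutionWino4x4To6x6. The 36 corresponds to the 6x6 transformed patch that produces a 4x4 output tile; tiles_x and tiles_y are presumably ceil-divisions of the output width and height by 4 (they are computed outside this hunk). A small standalone sketch of that arithmetic with hypothetical numbers:

#include <cstdio>

int main() {
  // Hypothetical destination shape: 1 x 33 x 65 x 32 (B x H x W x C).
  const int dst_h = 33, dst_w = 65, dst_c = 32;
  const int tiles_y = (dst_h + 3) / 4;  // ceil(33 / 4) = 9 tiles vertically
  const int tiles_x = (dst_w + 3) / 4;  // ceil(65 / 4) = 17 tiles horizontally
  // Intermediate Winograd tensor: 36 transformed positions per 4x4 output tile.
  std::printf("conv_shape = {1, 36, %d, %d}\n", tiles_x * tiles_y, dst_c);  // {1, 36, 153, 32}
  return 0;
}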
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -26,7 +27,8 @@ namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
// Builds CompiledModel out of GraphFloat32 graph using provided DeviceInfo and RuntimeOptions.
|
||||
absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
|
||||
absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options,
|
||||
CompiledModel* compiled_model);
|
||||
|
||||
} // namespace metal
|
||||
|
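With the new signature, callers of Compile() construct a DeviceInfo (from the Metal device name, per environment.h below) and pass it alongside the runtime options. A hedged usage fragment follows; header paths are assumed from the BUILD targets shown, `graph` is an already-populated GraphFloat32, and the device-name literal stands in for what would be queried from the MTLDevice.

#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/api.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"

// Hypothetical helper: compile an already-built graph for a named GPU.
absl::Status CompileForDevice(const tflite::gpu::GraphFloat32& graph,
                              tflite::gpu::metal::CompiledModel* compiled_model) {
  // In real code the name comes from the MTLDevice; a literal stands in here.
  tflite::gpu::metal::DeviceInfo device_info("Apple A12 GPU");
  tflite::gpu::metal::RuntimeOptions options;  // precision fields left at defaults
  return tflite::gpu::metal::Compile(graph, device_info, options, compiled_model);
}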
@ -16,21 +16,53 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
enum class GpuType {
|
||||
enum class AppleGPU {
|
||||
kUnknown,
|
||||
kA7, // iPhone 5s, iPad Air, iPad Mini 2, iPad Mini 3.
|
||||
kA8, // A8 iPhone 6, A8X iPad Air 2, iPad Mini 4.
|
||||
kA9, // A9 iPhone 6s, iPad (2017), A9X iPad Pro (1st generation).
|
||||
kA10, // iPhone 7, iPad (2018), A10X iPad Pro (2nd generation).
|
||||
kA11, // iPhone 8/X.
|
||||
kA12, // iPhone Xs.
|
||||
kA7,
|
||||
kA8,
|
||||
kA8X,
|
||||
kA9,
|
||||
kA9X,
|
||||
kA10,
|
||||
kA10X,
|
||||
kA11,
|
||||
kA12,
|
||||
kA12X,
|
||||
kA12Z,
|
||||
kA13,
|
||||
};
|
||||
|
||||
GpuType GetGpuType();
|
||||
struct AppleGPUInfo {
|
||||
AppleGPUInfo() = default;
|
||||
explicit AppleGPUInfo(const std::string& device_name);
|
||||
AppleGPU gpu_type;
|
||||
|
||||
bool IsLocalMemoryPreferredOverGlobal() const;
|
||||
|
||||
bool IsBionic() const;
|
||||
|
||||
// floating point rounding mode
|
||||
bool IsRoundToNearestSupported() const;
|
||||
|
||||
int GetComputeUnitsCount() const;
|
||||
};
|
||||
|
||||
struct DeviceInfo {
|
||||
DeviceInfo() = default;
|
||||
explicit DeviceInfo(const std::string& device_name);
|
||||
AppleGPUInfo apple_info;
|
||||
|
||||
// floating point rounding mode
|
||||
bool IsRoundToNearestSupported() const;
|
||||
|
||||
int GetComputeUnitsCount() const;
|
||||
};
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
|
@ -15,82 +15,93 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
|
||||
#import <Metal/Metal.h>
|
||||
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/common.h"
|
||||
#include <map>
|
||||
#include <string>
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace metal {
|
||||
|
||||
GpuType GetGpuType() {
|
||||
int max_feature_set = 0;
|
||||
#if defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0
|
||||
std::vector<std::pair<MTLFeatureSet, int>> features;
|
||||
if (@available(iOS 8.0, *)) {
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v1, 7);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v1, 8);
|
||||
}
|
||||
if (@available(iOS 9.0, *)) {
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v2, 7);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v2, 8);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v1, 9);
|
||||
}
|
||||
if (@available(iOS 10.0, *)) {
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v3, 7);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v3, 8);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v2, 9);
|
||||
}
|
||||
if (@available(iOS 11.0, *)) {
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v4, 8);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v3, 9);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v1, 11);
|
||||
}
|
||||
if (@available(iOS 12.0, *)) {
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v5, 7);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v5, 8);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v4, 9);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v2, 11);
|
||||
features.emplace_back(MTLFeatureSet_iOS_GPUFamily5_v1, 12);
|
||||
}
|
||||
id<MTLDevice> device = GetBestSupportedMetalDevice();
|
||||
for (auto &type : features) {
|
||||
if ([device supportsFeatureSet:type.first]) {
|
||||
max_feature_set = std::max(max_feature_set, type.second);
|
||||
}
|
||||
}
|
||||
#elif defined(__MAC_10_5) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_5
|
||||
std::vector<std::pair<MTLFeatureSet, int>> features;
|
||||
if (@available(macOS 10.15, *)) {
|
||||
features.emplace_back(MTLFeatureSet_macOS_GPUFamily2_v1, 12);
|
||||
}
|
||||
id<MTLDevice> device = GetBestSupportedMetalDevice();
|
||||
for (auto &type : features) {
|
||||
if ([device supportsFeatureSet:type.first]) {
|
||||
max_feature_set = std::max(max_feature_set, type.second);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
switch (max_feature_set) {
|
||||
case 7:
|
||||
return GpuType::kA7;
|
||||
case 8:
|
||||
return GpuType::kA8;
|
||||
case 9:
|
||||
return GpuType::kA9;
|
||||
case 10:
|
||||
return GpuType::kA10;
|
||||
case 11:
|
||||
return GpuType::kA11;
|
||||
case 12:
|
||||
return GpuType::kA12;
|
||||
default:
|
||||
return GpuType::kUnknown;
|
||||
AppleGPUInfo::AppleGPUInfo(const std::string& device_name) {
|
||||
const std::map<std::string, AppleGPU> kMapping = {
|
||||
{"Apple A7 GPU", AppleGPU::kA7},
|
||||
{"Apple A8 GPU", AppleGPU::kA8},
|
||||
{"Apple A8X GPU", AppleGPU::kA8X},
|
||||
{"Apple A9 GPU", AppleGPU::kA9},
|
||||
{"Apple A9X GPU", AppleGPU::kA9X},
|
||||
{"Apple A10 GPU", AppleGPU::kA10},
|
||||
{"Apple A10X GPU", AppleGPU::kA10X},
|
||||
{"Apple A11 GPU", AppleGPU::kA11},
|
||||
{"Apple A12 GPU", AppleGPU::kA12},
|
||||
{"Apple A12X GPU", AppleGPU::kA12X},
|
||||
{"Apple A12Z GPU", AppleGPU::kA12Z},
|
||||
{"Apple A13 GPU", AppleGPU::kA13},
|
||||
};
|
||||
auto it = kMapping.find(device_name);
|
||||
if (it != kMapping.end()) {
|
||||
gpu_type = it->second;
|
||||
} else {
|
||||
gpu_type = AppleGPU::kUnknown;
|
||||
}
|
||||
}
|
||||
|
||||
bool AppleGPUInfo::IsLocalMemoryPreferredOverGlobal() const {
|
||||
return gpu_type == AppleGPU::kA7 ||
|
||||
gpu_type == AppleGPU::kA8 ||
|
||||
gpu_type == AppleGPU::kA8X;
|
||||
}
|
||||
|
||||
bool AppleGPUInfo::IsBionic() const {
|
||||
return gpu_type == AppleGPU::kA11 ||
|
||||
gpu_type == AppleGPU::kA12 ||
|
||||
gpu_type == AppleGPU::kA12X ||
|
||||
gpu_type == AppleGPU::kA12Z ||
|
||||
gpu_type == AppleGPU::kA13;
|
||||
}
|
||||
|
||||
bool AppleGPUInfo::IsRoundToNearestSupported() const {
|
||||
return IsBionic();
|
||||
}
|
||||
|
||||
int AppleGPUInfo::GetComputeUnitsCount() const {
|
||||
switch (gpu_type) {
|
||||
case AppleGPU::kA7:
|
||||
return 4;
|
||||
case AppleGPU::kA8:
|
||||
return 4;
|
||||
case AppleGPU::kA8X:
|
||||
return 8;
|
||||
case AppleGPU::kA9:
|
||||
return 6;
|
||||
case AppleGPU::kA9X:
|
||||
return 12;
|
||||
case AppleGPU::kA10:
|
||||
return 6;
|
||||
case AppleGPU::kA10X:
|
||||
return 12;
|
||||
case AppleGPU::kA11:
|
||||
return 3;
|
||||
case AppleGPU::kA12:
|
||||
return 4;
|
||||
case AppleGPU::kA12X:
|
||||
return 7;
|
||||
case AppleGPU::kA12Z:
|
||||
return 8;
|
||||
case AppleGPU::kA13:
|
||||
return 4;
|
||||
case AppleGPU::kUnknown:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
DeviceInfo::DeviceInfo(const std::string& device_name)
    : apple_info(device_name) {}
|
||||
|
||||
bool DeviceInfo::IsRoundToNearestSupported() const {
|
||||
return apple_info.IsRoundToNearestSupported();
|
||||
}
|
||||
|
||||
int DeviceInfo::GetComputeUnitsCount() const {
|
||||
return apple_info.GetComputeUnitsCount();
|
||||
}
|
||||
|
||||
} // namespace metal
|
||||
|
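The new AppleGPUInfo/DeviceInfo pair replaces the old feature-set probing with a lookup keyed on the Metal device name, so the heuristics above (compute-unit counts, Bionic detection, local-memory preference) become pure functions of that string. A hedged usage sketch; the device-name literal is a placeholder for what the MTLDevice would report.

#include <cstdio>
#include <string>

#include "tensorflow/lite/delegates/gpu/metal/environment.h"

int main() {
  const std::string device_name = "Apple A12 GPU";  // placeholder device name
  tflite::gpu::metal::DeviceInfo device_info(device_name);
  // For the A12 mapping shown above: 4 compute units, Bionic, global memory preferred.
  std::printf("compute units: %d\n", device_info.GetComputeUnitsCount());
  std::printf("round-to-nearest: %d\n", device_info.IsRoundToNearestSupported());
  std::printf("prefers local memory: %d\n",
              device_info.apple_info.IsLocalMemoryPreferredOverGlobal());
  return 0;
}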
@ -1,49 +0,0 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
|
||||
#import <XCTest/XCTest.h>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/metal/common.h"
|
||||
|
||||
using ::tflite::gpu::metal::GetGpuType;
|
||||
|
||||
@interface EnvironmentTest : XCTestCase
|
||||
|
||||
@end
|
||||
|
||||
@implementation EnvironmentTest
|
||||
|
||||
- (void)testCompileTimeOSDetection {
|
||||
#if IOS_VERSION > 0
|
||||
XCTAssertTrue(MACOS_VERSION == 0 && TVOS_VERSION == 0, @"IOS_VERSION: %d", int{IOS_VERSION});
|
||||
#endif
|
||||
#if MACOS_VERSION > 0
|
||||
XCTAssertTrue(IOS_VERSION == 0 && TVOS_VERSION == 0, @"MACOS_VERSION: %d", int{MACOS_VERSION});
|
||||
#endif
|
||||
#if TVOS_VERSION > 0
|
||||
XCTAssertTrue(IOS_VERSION == 0 && MACOS_VERSION == 0, @"TVOS_VERSION: %d", int{TVOS_VERSION});
|
||||
#endif
|
||||
}
|
||||
|
||||
- (void)testGetGpuType {
|
||||
#if (IOS_VERSION > 0) || (TVOS_VERSION > 0)
|
||||
auto gpuType = GetGpuType();
|
||||
XCTAssertTrue(gpuType != GpuType::kUnknown);
|
||||
#endif
|
||||
}
|
||||
|
||||
@end
|
@ -833,6 +833,7 @@ objc_library(
|
||||
"//tensorflow/lite/delegates/gpu/metal:api",
|
||||
"//tensorflow/lite/delegates/gpu/metal:common",
|
||||
"//tensorflow/lite/delegates/gpu/metal:compiled_model",
|
||||
"//tensorflow/lite/delegates/gpu/metal:environment",
|
||||
"//tensorflow/lite/delegates/gpu/metal:inference_context",
|
||||
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
|
||||
"@FP16",
|
||||
|
@ -659,32 +659,19 @@ bool IsKernelYIs1(const Convolution2DAttributes& attr) {
|
||||
attr.padding.appended.h == 0;
|
||||
}
|
||||
|
||||
int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) {
|
||||
if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
|
||||
int GetMaximumPossibleWavesCount(const AppleGPUInfo& apple_info,
|
||||
const BHWC& dst_shape) {
|
||||
if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
|
||||
return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
|
||||
} else {
|
||||
return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1});
|
||||
}
|
||||
}
|
||||
|
||||
int GetCountOfComputeUnits(GpuType gpu) {
|
||||
if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
|
||||
return 4;
|
||||
} else if (gpu == GpuType::kA9 || gpu == GpuType::kA10) {
|
||||
return 6;
|
||||
} else if (gpu == GpuType::kA11) {
|
||||
return 3;
|
||||
} else if (gpu == GpuType::kA12) {
|
||||
return 4;
|
||||
} else {
|
||||
// unknown gpu
|
||||
return 4;
|
||||
}
|
||||
}
|
||||
|
||||
int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
|
||||
const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu);
|
||||
const int cu_count = GetCountOfComputeUnits(gpu);
|
||||
int GetRecommendedBlockSize(const AppleGPUInfo& apple_info,
|
||||
const BHWC& dst_shape) {
|
||||
const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
|
||||
const int cu_count = apple_info.GetComputeUnitsCount();
|
||||
if (max_waves >= cu_count * 64) {
|
||||
return 8;
|
||||
} else if (max_waves >= cu_count * 32) {
|
||||
@ -696,8 +683,9 @@ int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
|
||||
}
|
||||
}
|
||||
|
||||
ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
|
||||
const BHWC& dst_shape, GpuType gpu) {
|
||||
ConvParams GetConvParamsForA7A8(const AppleGPUInfo& apple_info,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC& dst_shape) {
|
||||
const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
|
||||
const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
|
||||
|
||||
@ -711,7 +699,7 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
|
||||
params.linear_whs = false;
|
||||
params.work_group_launch_order = int3(0, 1, 2);
|
||||
|
||||
int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
|
||||
int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
|
||||
|
||||
if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
|
||||
params.block_size.z = 4;
|
||||
@ -771,14 +759,14 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
|
||||
return params;
|
||||
}
|
||||
|
||||
ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
|
||||
const BHWC& dst_shape, GpuType gpu) {
|
||||
ConvParams GetConvParamsForA9AndHigher(const AppleGPUInfo& apple_info,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC& dst_shape) {
|
||||
const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
|
||||
const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
|
||||
int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
|
||||
bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12;
|
||||
int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
|
||||
int3 block_size = int3(1, 1, 1);
|
||||
if (blk_total_size >= 2 && apple_gpu) {
|
||||
if (blk_total_size >= 2 && apple_info.IsBionic()) {
|
||||
if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
|
||||
block_size.x = 2;
|
||||
} else {
|
||||
@ -816,7 +804,7 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
|
||||
params.work_group_size = int3(32, 1, 1);
|
||||
params.work_group_launch_order = int3(0, 1, 2);
|
||||
}
|
||||
float precise_threshold = gpu == GpuType::kA12 ? 1.0f : 1.04f;
|
||||
float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
|
||||
float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
|
||||
if (precise_ratio > precise_threshold) {
|
||||
params.linear_wh = false;
|
||||
@ -852,13 +840,13 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
|
||||
return params;
|
||||
}
|
||||
|
||||
ConvParams GetConvParams(const Convolution2DAttributes& attr,
|
||||
ConvParams GetConvParams(const DeviceInfo& device_info,
|
||||
const Convolution2DAttributes& attr,
|
||||
const BHWC& dst_shape) {
|
||||
auto gpu_type = GetGpuType();
|
||||
if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
|
||||
return GetConvParamsForA7A8(attr, dst_shape, gpu_type);
|
||||
if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
|
||||
return GetConvParamsForA7A8(device_info.apple_info, attr, dst_shape);
|
||||
} else {
|
||||
return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type);
|
||||
return GetConvParamsForA9AndHigher(device_info.apple_info, attr, dst_shape);
|
||||
}
|
||||
}
|
||||
|
||||
@ -898,8 +886,9 @@ std::pair<uint3, uint3> GetDispatchSizes(const ConvParams& params,
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
|
||||
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
|
||||
const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) {
|
||||
ConvParams params = GetConvParams(attr, dst_shape);
|
||||
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const metal::RuntimeOptions& options) {
|
||||
ConvParams params = GetConvParams(device_info, attr, dst_shape);
|
||||
|
||||
auto desc = std::make_shared<ComputeTaskDescriptor>();
|
||||
desc->id = id;
|
||||
@ -953,7 +942,8 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
|
||||
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
|
||||
const Convolution2DAttributes& attr, const RuntimeOptions& options) {
|
||||
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options) {
|
||||
const int dst_slices = IntegralDivideRoundUp(attr.weights.shape.o, 4);
|
||||
ConvParams params;
|
||||
params.work_group_launch_order = int3(2, 0, 1);
|
||||
@ -965,8 +955,7 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
|
||||
params.different_weights_for_height = true;
|
||||
params.x_kernel_is_1 = true;
|
||||
params.y_kernel_is_1 = true;
|
||||
auto gpu_type = GetGpuType();
|
||||
if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
|
||||
if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
|
||||
params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
|
||||
params.work_group_size = int3(32, 1, 1);
|
||||
params.block_size = int3(4, 1, 4);
|
||||
|
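The refactor above keys the convolution parameter selection on AppleGPUInfo instead of a free-standing GpuType, but the underlying arithmetic is unchanged: channel counts are packed into 4-wide slices and the block-size heuristic compares the achievable wave count against the compute-unit count. A standalone sketch of the slice arithmetic those heuristics consume, with the ceil-division helper re-implemented locally and hypothetical channel counts:

#include <cstdio>

// Same ceil-division the Metal backend calls IntegralDivideRoundUp.
int IntegralDivideRoundUp(int n, int d) { return (n + d - 1) / d; }

int main() {
  // Hypothetical convolution: 40 input channels, 70 output channels.
  const int src_channels = 40, dst_channels = 70;
  const int src_slices = IntegralDivideRoundUp(src_channels, 4);  // 10
  const int dst_slices = IntegralDivideRoundUp(dst_channels, 4);  // 18
  // Part of one rule shown above: a depth block of 4 is considered when the
  // slice count is a multiple of 4 or at least 16 (budget permitting).
  const bool z_block_4 = (dst_slices % 4 == 0) || (dst_slices >= 16);
  std::printf("src_slices=%d dst_slices=%d z_block_4=%d\n",
              src_slices, dst_slices, z_block_4);
  return 0;
}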
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/lite/delegates/gpu/common/model.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/operations.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
|
||||
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -29,11 +30,13 @@ namespace metal {
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
|
||||
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
|
||||
const Convolution2DAttributes& attr, const RuntimeOptions& options);
|
||||
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options);
|
||||
|
||||
std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
|
||||
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
|
||||
const Convolution2DAttributes& attr, const RuntimeOptions& options);
|
||||
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
|
||||
const RuntimeOptions& options);
|
||||
|
||||
} // namespace metal
|
||||
} // namespace gpu
|
||||
|