Merge branch 'master' into interface_16x8

Elena Zhelezina 2020-03-30 10:45:04 +01:00 committed by GitHub
commit b6c7284053
343 changed files with 6215 additions and 3872 deletions
Changed directories: tensorflow/c, tensorflow/compiler, tensorflow/core, tensorflow/go/op, tensorflow/lite/delegates

View File

@ -683,7 +683,11 @@ TFE_TensorHandle* TFE_NewTensorHandleFromScalar(TF_DataType data_type,
tensorflow::Tensor tensor(dtype, tensorflow::TensorShape({}));
std::memcpy(tensorflow::TensorCApi::Buffer(tensor)->data(), data, len);
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
status->status = tensorflow::Status::OK();
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
namespace {

View File

@ -28,6 +28,7 @@ tf_cuda_library(
"c_api_debug.cc",
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.cc",
"context_interface.h",
"operation_interface.cc",
@ -64,6 +65,7 @@ tf_cuda_library(
"//tensorflow/core/platform:errors",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/lib:traceme",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -97,6 +99,7 @@ filegroup(
srcs = [
"c_api_experimental.h",
"c_api_internal.h",
"c_api_unified_experimental.h",
"context_interface.h",
"dlpack.h",
"operation_interface.h",
@ -112,6 +115,7 @@ tf_cuda_library(
name = "c_api_internal",
srcs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
"context_interface.h",
"operation_interface.h",
"tensor_handle_interface.h",
@ -210,7 +214,6 @@ tf_cuda_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/distributed_runtime/rpc:grpc_server_lib",
"//tensorflow/core/platform:casts",
"@com_google_absl//absl/strings",
],
)
@ -219,8 +222,12 @@ tf_cuda_library(
name = "c_api_experimental",
srcs = [
"c_api_experimental.cc",
"c_api_unified_experimental.cc",
],
hdrs = [
"c_api_experimental.h",
"c_api_unified_experimental.h",
],
hdrs = ["c_api_experimental.h"],
copts = tf_copts() + tfe_xla_copts(),
visibility = ["//visibility:public"],
deps = select({
@ -246,6 +253,7 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:variant",
],
}) + select({
"//tensorflow:with_xla_support": [
@ -297,6 +305,30 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "c_api_unified_experimental_test",
size = "small",
srcs = [
"c_api_unified_experimental_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":c_api",
":c_api_experimental",
":c_api_test_util",
"//tensorflow/c:c_test_util",
"//tensorflow/cc/profiler",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
],
)
tf_cc_test(
name = "custom_device_test",
size = "small",

View File

@ -919,7 +919,10 @@ TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
tensorflow::Tensor tensor;
status->status = tensorflow::TF_TensorToTensor(t, &tensor);
if (!status->status.ok()) return nullptr;
return TFE_TensorHandle::CreateLocalHandle(tensor, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(tensor))};
}
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
@ -1074,10 +1077,12 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
}
return new TFE_TensorHandle{
std::unique_ptr<AbstractTensorHandleInterface>(h->handle->Copy())};
std::unique_ptr<tensorflow::AbstractTensorHandleInterface>(
h->handle->Copy())};
}
AbstractTensorHandleInterface* tensorflow::TensorHandleInterface::Copy() {
tensorflow::AbstractTensorHandleInterface*
tensorflow::TensorHandleInterface::Copy() {
handle_->Ref();
return new TensorHandleInterface(handle_);
}
@ -1166,8 +1171,7 @@ void* TFE_TensorHandleDevicePointer(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (VariantDeviceIsCustom(handle->device())) {
const tensorflow::Tensor* t;
status->status = handle->Tensor(&t);
@ -1228,19 +1232,17 @@ TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
tensorflow::Tensor t(static_cast<tensorflow::DataType>(dtype),
tensorflow::TensorShape(dimvec), buf);
buf->Unref();
tensorflow::TensorHandle* ret_handle;
if (custom_device == nullptr) {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), device, device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(std::move(t), device,
device, context))};
} else {
status->status = tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context, &ret_handle);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(
std::move(t), custom_device, context))};
}
if (!status->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(ret_handle)};
}
// This function will block till the operation that produces `h` has
@ -1254,9 +1256,7 @@ size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle* h,
return 0;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"TFE_TensorHandleDeviceMemorySize may not be called on a remote tensor "
@ -1309,8 +1309,8 @@ void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
num_inputs);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
handles[i].reset(inputs[i]->handle->Copy());
}
@ -1504,8 +1504,8 @@ TF_CAPI_EXPORT extern int TFE_OpGetOutputLength(TFE_Op* op,
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
TF_Status* status) {
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>> handles(
*num_retvals);
absl::FixedArray<std::unique_ptr<tensorflow::AbstractTensorHandleInterface>>
handles(*num_retvals);
status->status = op->operation->Execute(&handles, num_retvals);
if (!status->status.ok()) {
return;
@ -1529,10 +1529,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
status->status = context->FindCustomDeviceFromName(device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
&handle);
tensorflow::TensorHandleFromInterface(h->handle), &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1549,10 +1546,7 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
status->status = context->FindCustomDeviceFromName(handle_device_name, &dev);
if (status->status.ok()) {
status->status = dev->CopyTensorFromDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h->handle.get())
->Handle(),
device_name, &handle);
tensorflow::TensorHandleFromInterface(h->handle), device_name, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1562,9 +1556,8 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
// Handle regular case.
status->status = tensorflow::EagerCopyToDevice(
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle(),
context, &context->Executor(), device, false, &handle);
tensorflow::TensorHandleFromInterface(h->handle), context,
&context->Executor(), device, false, &handle);
if (status->status.ok()) {
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
@ -1622,7 +1615,9 @@ void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t,
TF_Status* status) {
return TFE_TensorHandle::CreateLocalHandle(t, status);
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(
tensorflow::TensorHandle::CreateLocalHandle(t))};
}
void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
@ -1767,9 +1762,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle =
device_.copy_tensor_to_device(context_, &tensor_handle, &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;
@ -1786,9 +1779,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
TFE_TensorHandle* result_handle = device_.copy_tensor_from_device(
context_, &tensor_handle, target_device_name.c_str(), &status, info_);
if (!status.status.ok()) return status.status;
*result = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
result_handle->handle.get())
->Handle();
*result = tensorflow::TensorHandleFromInterface(result_handle->handle);
(*result)->Ref();
delete result_handle;
return status.status;
@ -1812,9 +1803,7 @@ class CustomDeviceAPI : public tensorflow::CustomDevice {
&attributes, num_retvals, outputs.data(), &status, info_);
if (status.status.ok()) {
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
outputs[i]->handle.get())
->Handle();
retvals[i] = tensorflow::TensorHandleFromInterface(outputs[i]->handle);
retvals[i]->Ref();
delete outputs[i];
}

View File

@ -31,6 +31,7 @@ using tensorflow::string;
void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
const char* raw_device_name, TF_Status* status) {
if (op_to_reset) {
op_to_reset->operation->Clear();
status->status =
op_to_reset->operation->Reset(op_or_function_name, raw_device_name);
} else {

View File

@ -455,6 +455,7 @@ TEST(CAPI, TensorHandleOnDeviceMemory) {
TFE_DeleteTensorHandle(copy_aliased); // Note that this will delete copy.
TFE_DeleteTensorHandle(on_host);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(m_data);
TFE_DeleteTensorHandle(m);
TFE_DeleteContext(ctx);

View File

@ -69,18 +69,7 @@ struct TFE_Context {
};
struct TFE_TensorHandle {
static TFE_TensorHandle* CreateLocalHandle(const class tensorflow::Tensor& t,
TF_Status* s) {
tensorflow::TensorHandle* handle;
s->status = tensorflow::TensorHandle::CreateLocalHandle(t, &handle);
if (!s->status.ok()) {
return nullptr;
}
return new TFE_TensorHandle{
std::make_unique<tensorflow::TensorHandleInterface>(handle)};
}
std::unique_ptr<AbstractTensorHandleInterface> handle;
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> handle;
};
struct TFE_TensorDebugInfo {
@ -92,7 +81,7 @@ struct TFE_TensorDebugInfo {
};
struct TFE_Op {
std::unique_ptr<AbstractOperationInterface> operation;
std::unique_ptr<tensorflow::AbstractOperationInterface> operation;
};
struct TFE_MonitoringCounterCell {

View File

@ -184,9 +184,7 @@ void TestRemoteExecuteSilentCopies(bool async, bool remote) {
// TODO(gjn): Add support for waiting on async local mirrors
if (!async) {
auto remote_arg = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
h1_task2->handle.get())
->Handle();
auto remote_arg = tensorflow::TensorHandleFromInterface(h1_task2->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());
// The input handles should never change since they have been mirrored.

View File

@ -409,13 +409,8 @@ void TensorHandleSilentCopy(bool async,
ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
// Validate if the input was replaced with a different TensorHandle
auto arg0 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hcpu->handle.get())
->Handle();
auto arg1 = tensorflow::down_cast<tensorflow::TensorHandleInterface*>(
hgpu->handle.get())
->Handle();
auto arg0 = tensorflow::TensorHandleFromInterface(hcpu->handle);
auto arg1 = tensorflow::TensorHandleFromInterface(hgpu->handle);
auto op = tensorflow::down_cast<tensorflow::OperationInterface*>(
matmul->operation.get());

View File

@ -0,0 +1,261 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
using tensorflow::string;
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
typedef void (*ExecuteOperation)(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs,
TF_OutputList* o, TF_ExecutionContext* ctx,
TF_Status* s);
struct TF_ExecutionContext {
explicit TF_ExecutionContext() {}
absl::variant<TFE_Context*, TF_GraphContext*> ctx;
ExecuteOperation execution_callback;
};
struct TF_AbstractTensor {
absl::variant<TFE_TensorHandle*, TF_GraphTensor*> t;
};
struct TF_AbstractOp {
string op_type;
string op_name;
};
TF_ExecutionContext* TF_NewExecutionContext() {
return new TF_ExecutionContext();
}
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete c; }
TF_AbstractOp* TF_NewAbstractOp() {
TF_AbstractOp* op = new TF_AbstractOp;
return op;
}
void TF_DeleteAbstractOp(TF_AbstractOp* op) { delete op; }
TF_AbstractTensor* TF_NewAbstractTensor() {
TF_AbstractTensor* t = new TF_AbstractTensor;
return t;
}
void TF_DeleteAbstractTensor(TF_AbstractTensor* t) { delete t; }
struct TF_GraphContext {
TF_Graph* graph;
// TODO(srbs): Handle captures.
};
TF_GraphContext* TF_NewGraphContext(TF_Graph* g) {
auto ctx = new TF_GraphContext;
ctx->graph = g;
return ctx;
}
void TF_DeleteGraphContext(TF_GraphContext* ctx) { delete ctx; }
struct TF_GraphTensor {
TF_Output output;
TF_GraphContext* ctx;
};
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* ctx, TF_Output output,
TF_Status* s) {
TF_GraphTensor* t = new TF_GraphTensor;
t->output = output;
t->ctx = ctx;
return t;
}
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s) {
return t->output;
}
void TF_DeleteGraphTensor(TF_GraphTensor* t) { delete t; }
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s) {
at->t = t;
}
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TFE_TensorHandle*>(at->t)) {
string msg = absl::StrCat("Not an eager tensor handle.",
reinterpret_cast<uintptr_t>(at));
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TFE_TensorHandle*>(at->t);
}
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s) {
at->t = t;
}
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s) {
if (!absl::holds_alternative<TF_GraphTensor*>(at->t)) {
string msg = absl::StrCat("Not a graph tensor handle.");
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return nullptr;
}
return absl::get<TF_GraphTensor*>(at->t);
}
bool IsEagerTensor(const TF_AbstractTensor* const t) {
return absl::holds_alternative<TFE_TensorHandle*>(t->t);
}
struct TF_OutputList {
std::vector<TF_AbstractTensor*> outputs;
int expected_num_outputs = -1;
};
TF_OutputList* TF_NewOutputList() { return new TF_OutputList; }
void TF_DeleteOutputList(TF_OutputList* o) { delete o; }
void TF_OutputListSetNumOutputs(TF_OutputList* o, int num_outputs,
TF_Status* s) {
o->expected_num_outputs = num_outputs;
}
int TF_OutputListNumOutputs(TF_OutputList* o) { return o->outputs.size(); }
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i) {
return o->outputs[i];
}
void ExecuteOperationEager(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
auto* tfe_op =
TFE_NewOp(absl::get<TFE_Context*>(ctx->ctx), op->op_type.c_str(), s);
if (TF_GetCode(s) != TF_OK) return;
for (int i = 0; i < num_inputs; ++i) {
if (!IsEagerTensor(inputs[i])) {
TF_SetStatus(s, TF_INVALID_ARGUMENT, "Not an eager tensor.");
return;
}
TFE_OpAddInput(tfe_op, absl::get<TFE_TensorHandle*>(inputs[i]->t), s);
if (TF_GetCode(s) != TF_OK) return;
}
if (o->expected_num_outputs == -1) {
string msg =
"The number of outputs must be provided in eager mode. Use "
"TF_OutputListSetNumOutputs.";
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
return;
}
tensorflow::gtl::InlinedVector<TFE_TensorHandle*, 2> retvals;
int num_retvals = o->expected_num_outputs;
retvals.resize(num_retvals);
TFE_Execute(tfe_op, retvals.data(), &num_retvals, s);
TFE_DeleteOp(tfe_op);
if (TF_GetCode(s) != TF_OK) {
return;
}
o->outputs.clear();
o->outputs.reserve(num_retvals);
for (int i = 0; i < num_retvals; ++i) {
auto* t = TF_NewAbstractTensor();
t->t = retvals[i];
o->outputs.push_back(t);
}
}
TF_GraphContext* GetGraphContext(TF_AbstractTensor const* t) {
return absl::get<TF_GraphTensor*>(t->t)->ctx;
}
void ExecuteOperationGraph(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
TF_GraphContext* graph_ctx = absl::get<TF_GraphContext*>(ctx->ctx);
TF_Graph* g = graph_ctx->graph;
auto* tf_opdesc =
TF_NewOperation(g, op->op_type.c_str(), op->op_name.c_str());
for (int i = 0; i < num_inputs; ++i) {
auto* input = inputs[i];
if (IsEagerTensor(input)) {
TF_SetStatus(s, TF_INVALID_ARGUMENT,
"Capturing eager tensors is not supported yet.");
return;
} else {
if (GetGraphContext(input) != graph_ctx) {
TF_SetStatus(
s, TF_INVALID_ARGUMENT,
"Capturing tensors from other graphs is not supported yet.");
return;
}
TF_AddInput(tf_opdesc, absl::get<TF_GraphTensor*>(input->t)->output);
}
}
auto* operation = TF_FinishOperation(tf_opdesc, s);
if (TF_GetCode(s) != TF_OK) return;
int num_outputs = TF_OperationNumOutputs(operation);
o->outputs.clear();
o->outputs.reserve(num_outputs);
for (int i = 0; i < num_outputs; ++i) {
auto* t = TF_NewAbstractTensor();
TF_GraphTensor* output_t = TF_NewGraphTensor(graph_ctx, {operation, i}, s);
if (TF_GetCode(s) != TF_OK) {
return;
}
t->t = output_t;
o->outputs.push_back(t);
}
}
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context,
TF_Status* s) {
context->ctx = eager_context;
context->execution_callback = &ExecuteOperationEager;
}
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status* s) {
context->ctx = graph_context;
context->execution_callback = &ExecuteOperationGraph;
}
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s) {
op->op_type = op_type;
}
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s) {
op->op_name = op_name;
}
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s) {
ctx->execution_callback(op, num_inputs, inputs, o, ctx, s);
}

View File

@ -0,0 +1,119 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#define TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#ifdef __cplusplus
extern "C" {
#endif
// =============================================================================
// Unified Execution APIs for Eager and tracing backends.
// =============================================================================
// -----------------------------------------------------------------------------
// Core APIs
// -----------------------------------------------------------------------------
// A TF_ExecutionContext stores knowledge about how to execute an operation.
// E.g. it could know whether we're in eager mode or in graph mode, keeps track
// of gradient tapes, etc.
typedef struct TF_ExecutionContext TF_ExecutionContext;
// A TF_AbstractTensor is an input to an operation. E.g. it could be a union
// type of eager and graph tensors.
typedef struct TF_AbstractTensor TF_AbstractTensor;
// A TF_AbstractOp is the metadata we need to execute an operation. E.g. this
// could contain the op type and other attributes.
typedef struct TF_AbstractOp TF_AbstractOp;
TF_ExecutionContext* TF_NewExecutionContext();
void TF_DeleteExecutionContext(TF_ExecutionContext*);
TF_AbstractOp* TF_NewAbstractOp();
void TF_DeleteAbstractOp(TF_AbstractOp*);
TF_AbstractTensor* TF_NewAbstractTensor();
void TF_DeleteAbstractTensor(TF_AbstractTensor*);
// -----------------------------------------------------------------------------
// APIs for Eager and graph modes
// -----------------------------------------------------------------------------
// Keeps track of the current graph and other state e.g. captures etc.
typedef struct TF_GraphContext TF_GraphContext;
TF_GraphContext* TF_NewGraphContext(TF_Graph*);
void TF_DeleteGraphContext(TF_GraphContext*);
// `eager_context` must outlive `context`.
void TF_ExecutionContextSetEagerContext(TF_ExecutionContext* context,
TFE_Context* eager_context, TF_Status*);
// `graph_context` must outlive `context`.
void TF_ExecutionContextSetGraphContext(TF_ExecutionContext* context,
TF_GraphContext* graph_context,
TF_Status*);
// TODO(srbs): Add APIs for specifying attrs etc.
// `op_type` must outlive `op`.
void TF_AbstractOpSetOpType(TF_AbstractOp* op, const char* const op_type,
TF_Status* s);
// `op_name` must outlive `op`.
void TF_AbstractOpSetOpName(TF_AbstractOp* op, const char* const op_name,
TF_Status* s);
// Wrapper for TF_Output but contains a pointer to TF_GraphContext as well.
typedef struct TF_GraphTensor TF_GraphTensor;
TF_GraphTensor* TF_NewGraphTensor(TF_GraphContext* c, TF_Output t,
TF_Status* s);
TF_Output TF_GraphTensorToOutput(const TF_GraphTensor* const t, TF_Status* s);
void TF_DeleteGraphTensor(TF_GraphTensor* t);
// `t` must outlive `at`.
void TF_AbstractTensorSetEagerTensor(TF_AbstractTensor* at, TFE_TensorHandle* t,
TF_Status* s);
TFE_TensorHandle* TF_AbstractTensorGetEagerTensor(TF_AbstractTensor* at,
TF_Status* s);
// `t` must outlive `at`.
void TF_AbstractTensorSetGraphTensor(TF_AbstractTensor* at, TF_GraphTensor* t,
TF_Status* s);
TF_GraphTensor* TF_AbstractTensorGetGraphTensor(TF_AbstractTensor* at,
TF_Status* s);
// TF_OutputList just lets us not specify the number of outputs of an operation
// beforehand. This forces a memory allocation in the runtime, which is bad, but
// it allows for generic code.
typedef struct TF_OutputList TF_OutputList;
TF_OutputList* TF_NewOutputList();
void TF_DeleteOutputList(TF_OutputList* o);
void TF_OutputListSetNumOutputs(TF_OutputList* o, int, TF_Status*);
int TF_OutputListNumOutputs(TF_OutputList* o);
TF_AbstractTensor* TF_OutputListGet(TF_OutputList* o, int i);
// TF_ExecuteOperation executes immediately in eager mode; in graph mode it may
// capture some inputs and then add a node to the graph. After execution/node
// creation it records what happened on any tape that happens to be active.
void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
TF_AbstractTensor* const* inputs, TF_OutputList* o,
TF_ExecutionContext* ctx, TF_Status* s);
#ifdef __cplusplus
} /* end extern "C" */
#endif
#endif // TENSORFLOW_C_EAGER_C_API_UNIFIED_EXPERIMENTAL_H_
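The next file adds tests that exercise both backends; for quick orientation, the following is a condensed sketch of the eager path using only functions declared above. The function name is illustrative, TestScalarTensorHandle comes from the eager C API test utilities, and error checks and cleanup are omitted.

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"  // TestScalarTensorHandle
#include "tensorflow/c/eager/c_api_unified_experimental.h"

// Sketch only: run Add(2, 2) eagerly through the unified API.
void RunAddEagerly() {
  TF_Status* s = TF_NewStatus();
  TF_ExecutionContext* ctx = TF_NewExecutionContext();

  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_Context* eager_ctx = TFE_NewContext(opts, s);
  TFE_DeleteContextOptions(opts);
  TF_ExecutionContextSetEagerContext(ctx, eager_ctx, s);  // select the eager backend

  TFE_TensorHandle* two = TestScalarTensorHandle(2.0f);
  TF_AbstractTensor* at = TF_NewAbstractTensor();
  TF_AbstractTensorSetEagerTensor(at, two, s);

  TF_AbstractOp* add = TF_NewAbstractOp();
  TF_AbstractOpSetOpType(add, "Add", s);

  TF_AbstractTensor* inputs[2] = {at, at};
  TF_OutputList* o = TF_NewOutputList();
  TF_OutputListSetNumOutputs(o, 1, s);  // eager mode requires the output count
  TF_ExecuteOperation(add, 2, inputs, o, ctx, s);

  TFE_TensorHandle* result =
      TF_AbstractTensorGetEagerTensor(TF_OutputListGet(o, 0), s);
  // ... resolve `result`, then delete the handles, o, add, at, ctx and s as in the test.
}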

View File

@ -0,0 +1,204 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include <string.h>
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/cc/profiler/profiler.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
using tensorflow::string;
namespace tensorflow {
namespace {
TEST(UnifedCAPI, TestBasicEager) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* t = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* at = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(at, t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {at, at};
TF_OutputList* o = TF_NewOutputList();
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(at);
TFE_DeleteTensorHandle(t);
// Verify the results.
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* result_tensor = TFE_TensorHandleResolve(result_t, status.get());
float* result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 4.0);
TF_DeleteTensor(result_tensor);
TF_DeleteAbstractTensor(result);
TFE_DeleteTensorHandle(result_t);
TF_DeleteOutputList(o);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}
TEST(UnifedCAPI, TestBasicGraph) {
TF_ExecutionContext* ctx = TF_NewExecutionContext();
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
// Enter a graph context.
TF_Graph* g = TF_NewGraph();
TF_GraphContext* graph_context = TF_NewGraphContext(g);
TF_ExecutionContextSetGraphContext(ctx, graph_context, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Add a placeholder to the graph.
auto* placeholder_op = TF_NewOperation(g, "Placeholder", "Placeholder");
TF_SetAttrType(placeholder_op, "dtype", TF_FLOAT);
auto* operation = TF_FinishOperation(placeholder_op, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output placeholder_t = {operation, 0};
TF_GraphTensor* graph_t =
TF_NewGraphTensor(graph_context, placeholder_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractTensor* t = TF_NewAbstractTensor();
TF_AbstractTensorSetGraphTensor(t, graph_t, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
auto* op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(op, "Add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOpSetOpName(op, "my_add", status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build inputs and outputs.
TF_AbstractTensor* inputs[2] = {t, t};
TF_OutputList* o = TF_NewOutputList();
// Execute.
TF_ExecuteOperation(op, 2, inputs, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Clean up operation and inputs.
TF_DeleteAbstractOp(op);
TF_DeleteAbstractTensor(t);
TF_DeleteGraphTensor(graph_t);
TF_AbstractTensor* result = TF_OutputListGet(o, 0);
TF_GraphTensor* result_graph_tensor =
TF_AbstractTensorGetGraphTensor(result, status.get());
TF_DeleteAbstractTensor(result);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Output result_output =
TF_GraphTensorToOutput(result_graph_tensor, status.get());
TF_DeleteGraphTensor(result_graph_tensor);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
string fn_name = "double";
TF_Function* f = TF_GraphToFunction(
g, fn_name.c_str(), 0, -1, nullptr, 1, &placeholder_t, 1, &result_output,
nullptr, nullptr, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an eager context to run the function.
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* eager_ctx = TFE_NewContext(opts, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_DeleteContextOptions(opts);
// Build the abstract op to run the function.
TFE_ContextAddFunction(eager_ctx, f, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_AbstractOp* fn_op = TF_NewAbstractOp();
TF_AbstractOpSetOpType(fn_op, fn_name.c_str(), status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract input tensor.
TFE_TensorHandle* input_eager = TestScalarTensorHandle(2.0f);
TF_AbstractTensor* input_t = TF_NewAbstractTensor();
TF_AbstractTensorSetEagerTensor(input_t, input_eager, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Enter the eager context.
TF_ExecutionContextSetEagerContext(ctx, eager_ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_OutputListSetNumOutputs(o, 1, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_ExecuteOperation(fn_op, 1, &input_t, o, ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
ASSERT_EQ(1, TF_OutputListNumOutputs(o));
TF_AbstractTensor* final_result = TF_OutputListGet(o, 0);
TFE_TensorHandle* final =
TF_AbstractTensorGetEagerTensor(final_result, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_Tensor* f_t = TFE_TensorHandleResolve(final, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
float* f_value = static_cast<float*>(TF_TensorData(f_t));
ASSERT_EQ(*f_value, 4.0);
TF_DeleteOutputList(o);
TF_DeleteAbstractOp(fn_op);
TF_DeleteAbstractTensor(input_t);
TFE_DeleteTensorHandle(input_eager);
TF_DeleteAbstractTensor(final_result);
TFE_DeleteTensorHandle(final);
TF_DeleteTensor(f_t);
TF_DeleteFunction(f);
TF_DeleteGraphContext(graph_context);
TF_DeleteGraph(g);
TFE_DeleteContext(eager_ctx);
TF_DeleteExecutionContext(ctx);
}
} // namespace
} // namespace tensorflow

View File

@ -121,20 +121,13 @@ std::unique_ptr<AbstractTensorInterface> ContextInterface::CreateBoolTensor(
Tensor(DT_BOOL, TensorShape(dim_sizes)));
}
Status ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* h) {
std::unique_ptr<AbstractTensorHandleInterface>
ContextInterface::CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) {
Tensor tensor = tensorflow::down_cast<TensorInterface*>(t.get())->Tensor();
tensorflow::TensorHandle* handle = nullptr;
auto status =
return std::make_unique<TensorHandleInterface>(
TensorHandle::CreateLocalHandle(std::move(tensor), /*d=*/ctx_->HostCPU(),
/*op_device=*/nullptr, ctx_, &handle);
if (!status.ok()) {
return status;
}
*h = std::make_unique<TensorHandleInterface>(handle);
return status;
/*op_device=*/nullptr, ctx_));
}
std::unique_ptr<AbstractOperationInterface>

View File

@ -75,16 +75,14 @@ class AbstractContextInterface {
absl::Span<const int64> dim_sizes) = 0;
// Create a handle to wrap and manage a Tensor
virtual tensorflow::Status CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* handle) = 0;
virtual std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) = 0;
// Create an operation to perform op execution
virtual std::unique_ptr<AbstractOperationInterface> CreateOperation() = 0;
// List attributes of available devices
virtual void ListDevices(
std::vector<tensorflow::DeviceAttributes>* devices) = 0;
virtual void ListDevices(std::vector<DeviceAttributes>* devices) = 0;
};
// TODO(gjn): Try to move these all to EagerContext and make it implement
@ -133,12 +131,11 @@ class ContextInterface : public AbstractContextInterface {
std::unique_ptr<AbstractTensorInterface> CreateBoolTensor(
absl::Span<const int64> dim_sizes) override;
tensorflow::Status CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t,
std::unique_ptr<AbstractTensorHandleInterface>* h) override;
std::unique_ptr<AbstractTensorHandleInterface> CreateLocalHandle(
const std::unique_ptr<AbstractTensorInterface> t) override;
std::unique_ptr<AbstractOperationInterface> CreateOperation() override;
void ListDevices(std::vector<tensorflow::DeviceAttributes>* devices) override;
void ListDevices(std::vector<DeviceAttributes>* devices) override;
// For runtime specific APIs, provide ability to get the underlying context.
EagerContext* Context() const { return ctx_; }
@ -149,7 +146,7 @@ class ContextInterface : public AbstractContextInterface {
inline EagerContext* ContextFromInterface(
const std::unique_ptr<AbstractContextInterface>& context) {
return down_cast<tensorflow::ContextInterface*>(context.get())->Context();
return down_cast<ContextInterface*>(context.get())->Context();
}
} // namespace tensorflow
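A hypothetical call site (not part of this diff) illustrating the revised CreateLocalHandle above: the handle is now returned directly instead of being written to an output parameter alongside a Status.

#include <memory>
#include <utility>

#include "tensorflow/c/eager/context_interface.h"

// Sketch only: wrap an abstract tensor in a handle via the revised interface.
std::unique_ptr<tensorflow::AbstractTensorHandleInterface> WrapTensor(
    tensorflow::AbstractContextInterface* ctx,
    std::unique_ptr<tensorflow::AbstractTensorInterface> tensor) {
  return ctx->CreateLocalHandle(std::move(tensor));
}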

View File

@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
@ -47,9 +46,7 @@ const Tensor* GetTensorFromHandle(TFE_TensorHandle* h, TF_Status* status) {
return nullptr;
}
tensorflow::TensorHandle* handle =
tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())
->Handle();
tensorflow::TensorHandleFromInterface(h->handle);
if (handle->IsRemote()) {
status->status = tensorflow::errors::InvalidArgument(
"DLPack doesn't support remote tensor");

View File

@ -266,8 +266,7 @@ Status OperationInterface::OutputLength(const char* output_name, int* length) {
Status OperationInterface::AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
return operation_.MaybeInferSingleInputAttrs(h);
}
@ -276,8 +275,7 @@ Status OperationInterface::AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) {
for (auto& input : inputs) {
TensorHandle* h =
tensorflow::down_cast<TensorHandleInterface*>(input.get())->Handle();
TensorHandle* h = TensorHandleFromInterface(input);
operation_.AddInput(h);
}
return operation_.InferInputListAttrs(inputs.size());

View File

@ -23,91 +23,75 @@ limitations under the License.
#include "tensorflow/c/eager/tensor_handle_interface.h"
#include "tensorflow/core/common_runtime/eager/eager_operation.h"
namespace tensorflow {
// Abstract interface to an operation.
class AbstractOperationInterface {
public:
virtual ~AbstractOperationInterface() {}
virtual void Clear() = 0;
virtual tensorflow::Status Reset(const char* op,
const char* raw_device_name) = 0;
virtual Status Reset(const char* op, const char* raw_device_name) = 0;
virtual const tensorflow::string& Name() const = 0;
virtual const tensorflow::string& DeviceName() const = 0;
virtual tensorflow::Status SetDeviceName(const char* name) = 0;
virtual const string& Name() const = 0;
virtual const string& DeviceName() const = 0;
virtual Status SetDeviceName(const char* name) = 0;
virtual tensorflow::Status AddInput(
virtual Status AddInput(
const std::unique_ptr<AbstractTensorHandleInterface>& input) = 0;
virtual tensorflow::Status AddInputList(
virtual Status AddInputList(
const absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>&
inputs) = 0;
virtual tensorflow::Status Execute(
virtual Status Execute(
absl::FixedArray<std::unique_ptr<AbstractTensorHandleInterface>>* retvals,
int* num_retvals) = 0;
virtual const tensorflow::OpDef* OpDef() const = 0;
virtual tensorflow::Status SetAttrString(const char* attr_name,
const char* data, size_t length) = 0;
virtual tensorflow::Status SetAttrInt(const char* attr_name,
int64_t value) = 0;
virtual tensorflow::Status SetAttrFloat(const char* attr_name,
float value) = 0;
virtual tensorflow::Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual tensorflow::Status SetAttrType(const char* attr_name,
TF_DataType value) = 0;
virtual tensorflow::Status SetAttrShape(const char* attr_name,
const int64_t* dims,
const int num_dims) = 0;
virtual tensorflow::Status SetAttrFunction(
virtual Status SetAttrString(const char* attr_name, const char* data,
size_t length) = 0;
virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
virtual Status SetAttrType(const char* attr_name, TF_DataType value) = 0;
virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
const int num_dims) = 0;
virtual Status SetAttrFunction(
const char* attr_name,
const std::unique_ptr<AbstractOperationInterface>& value) = 0;
virtual tensorflow::Status SetAttrFunctionName(const char* attr_name,
const char* value,
size_t length) = 0;
virtual tensorflow::Status SetAttrTensor(const char* attr_name,
TF_Tensor* tensor) = 0;
virtual tensorflow::Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths,
int num_values) = 0;
virtual tensorflow::Status SetAttrFloatList(const char* attr_name,
const float* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrIntList(const char* attr_name,
const int64_t* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual tensorflow::Status SetAttrShapeList(const char* attr_name,
const int64_t** dims,
const int* num_dims,
int num_values) = 0;
virtual tensorflow::Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value,
int num_values) = 0;
virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
size_t length) = 0;
virtual Status SetAttrTensor(const char* attr_name, TF_Tensor* tensor) = 0;
virtual Status SetAttrStringList(const char* attr_name,
const void* const* values,
const size_t* lengths, int num_values) = 0;
virtual Status SetAttrFloatList(const char* attr_name, const float* values,
int num_values) = 0;
virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
int num_values) = 0;
virtual Status SetAttrTypeList(const char* attr_name,
const TF_DataType* values, int num_values) = 0;
virtual Status SetAttrBoolList(const char* attr_name,
const unsigned char* values,
int num_values) = 0;
virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
const int* num_dims, int num_values) = 0;
virtual Status SetAttrFunctionList(const char* attr_name,
const TFE_Op** value, int num_values) = 0;
virtual tensorflow::Status InputLength(const char* input_name,
int* length) = 0;
virtual tensorflow::Status OutputLength(const char* output_name,
int* length) = 0;
virtual Status InputLength(const char* input_name, int* length) = 0;
virtual Status OutputLength(const char* output_name, int* length) = 0;
// Experimental
virtual tensorflow::Status SetUseXla(bool enable) {
return tensorflow::errors::Unimplemented("SetUseXla not implemented");
virtual Status SetUseXla(bool enable) {
return errors::Unimplemented("SetUseXla not implemented");
}
virtual tensorflow::Status SetCancellationManager(
virtual Status SetCancellationManager(
TFE_CancellationManager* cancellation_manager) {
return tensorflow::errors::Unimplemented(
"SetCancellationManager not implemented");
return errors::Unimplemented("SetCancellationManager not implemented");
}
};
namespace tensorflow {
class OpDef;
class OperationInterface : public AbstractOperationInterface {
@ -173,8 +157,8 @@ class OperationInterface : public AbstractOperationInterface {
TFE_CancellationManager* cancellation_manager) override;
// TODO(gjn): Remove once TFE_InferShapes is removed
const tensorflow::AttrBuilder& Attrs() const { return operation_.Attrs(); }
tensorflow::AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const AttrBuilder& Attrs() const { return operation_.Attrs(); }
AttrBuilder* MutableAttrs() { return operation_.MutableAttrs(); }
const TensorHandle* GetInput(int i) const { return operation_.Inputs()[i]; }

View File

@ -19,6 +19,9 @@ limitations under the License.
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/platform/casts.h"
namespace tensorflow {
// Abstract interface to a TensorHandle.
//
@ -34,24 +37,24 @@ class AbstractTensorHandleInterface {
virtual ~AbstractTensorHandleInterface() {}
// Check if the handle is in a valid initialized state.
virtual bool IsValid(tensorflow::Status* status) const = 0;
virtual bool IsValid(Status* status) const = 0;
// Returns tensor dtype.
virtual TF_DataType DataType() const = 0;
// Returns number of dimensions.
virtual int NumDims(tensorflow::Status* status) const = 0;
virtual int NumDims(Status* status) const = 0;
// Returns number of elements across all dimensions.
virtual int64_t NumElements(tensorflow::Status* status) const = 0;
virtual int64_t NumElements(Status* status) const = 0;
// Returns size of specified dimension
virtual int64_t Dim(int dim_index, tensorflow::Status* status) const = 0;
virtual int64_t Dim(int dim_index, Status* status) const = 0;
// Returns the device which created the handle.
virtual const char* DeviceName(tensorflow::Status* status) const = 0;
virtual const char* DeviceName(Status* status) const = 0;
// Returns the device where the tensor was placed.
virtual const char* BackingDeviceName(tensorflow::Status* status) const = 0;
virtual const char* BackingDeviceName(Status* status) const = 0;
// Returns a tensor for the handle. If tensor is remote, it will be copied.
virtual TF_Tensor* Resolve(tensorflow::Status* status) = 0;
virtual TF_Tensor* Resolve(Status* status) = 0;
// Returns debug information about the tensor.
virtual TFE_TensorDebugInfo* TensorDebugInfo(tensorflow::Status* status) = 0;
virtual TFE_TensorDebugInfo* TensorDebugInfo(Status* status) = 0;
// Return a copy of the handle.
virtual AbstractTensorHandleInterface* Copy() = 0;
@ -65,8 +68,9 @@ class AbstractTensorHandleInterface {
virtual void EnableImplicitMirroring() = 0;
};
namespace tensorflow {
// TODO(gjn): Try to move these all to TensorHandle and make it implement
// AbstractTensorHandleInterface. Currently, this is not so straightforward
// because of various BUILD file dependencies.
class TensorHandleInterface : public AbstractTensorHandleInterface {
public:
explicit TensorHandleInterface(TensorHandle* h) : handle_(h) {}
@ -87,14 +91,18 @@ class TensorHandleInterface : public AbstractTensorHandleInterface {
void EnableImplicitMirroring() override;
// TODO(gjn): This is not a very generic interface, but is needed for specific
// use cases.
// For runtime specific APIs, provide ability to get the underlying handle.
TensorHandle* Handle() { return handle_; }
private:
TensorHandle* handle_;
};
inline TensorHandle* TensorHandleFromInterface(
const std::unique_ptr<AbstractTensorHandleInterface>& handle) {
return down_cast<TensorHandleInterface*>(handle.get())->Handle();
}
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_TENSOR_HANDLE_INTERFACE_H_
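For reference, the TensorHandleFromInterface helper added above is what lets the rest of this change drop the repeated down_cast pattern; a minimal sketch of the two call styles (the wrapper function name is illustrative, not part of the commit):

#include "tensorflow/c/eager/c_api_internal.h"          // TFE_TensorHandle
#include "tensorflow/c/eager/tensor_handle_interface.h"

// Sketch only.
tensorflow::TensorHandle* UnwrapHandle(TFE_TensorHandle* h) {
  // Before this change, call sites spelled this out as:
  //   tensorflow::down_cast<tensorflow::TensorHandleInterface*>(h->handle.get())->Handle();
  // Now they use the inline helper added above:
  return tensorflow::TensorHandleFromInterface(h->handle);
}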

View File

@ -42,6 +42,20 @@ class TFLiteCostEstimator<AveragePool2DOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.concatenation
template <>
class TFLiteCostEstimator<ConcatenationOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): We probably need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.conv_2d
template <>
class TFLiteCostEstimator<Conv2DOp, hardware::GPU> {
@ -69,6 +83,20 @@ class TFLiteCostEstimator<DepthwiseConv2DOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.fully_connected
template <>
class TFLiteCostEstimator<FullyConnectedOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
// TODO(renjieliu): we need to check for dynamic weights.
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.logistic
template <>
class TFLiteCostEstimator<LogisticOp, hardware::GPU> {
@ -95,6 +123,19 @@ class TFLiteCostEstimator<MaxPool2DOp, hardware::GPU> {
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mirror_pad
template <>
class TFLiteCostEstimator<MirrorPadOp, hardware::GPU> {
public:
static double GetCost(mlir::Operation* op) {
llvm::errs() << "No defined cost function for op: "
<< op->getName().getStringRef().str();
return 0.0;
}
static bool IsSupported(mlir::Operation* op) { return true; }
};
// tfl.mul
template <>
class TFLiteCostEstimator<MulOp, hardware::GPU> {

View File

@ -579,7 +579,8 @@ def TFL_ConcatenationOp : TFL_Op<"concatenation",
NoSideEffect,
PredOpTrait<"values and output must have same element type",
TFL_TCresVTEtIsSameAsOp<0, 0>>,
SameOperandsAndResultsScale
SameOperandsAndResultsScale,
TFL_GpuTargetOp
]> {
let summary = "Concatenation operator";
@ -754,7 +755,8 @@ def TFL_FullyConnectedOp : TFL_Op<"fully_connected", [
NoSideEffect, AccumulatorUniformScale<2, 0, 1>,
TFL_ChannelDimIndexInterface,
AffineOpCoefficient<-1, 1>,
TFL_SparseOp]> {
TFL_SparseOp,
TFL_GpuTargetOp]> {
let summary = "Fully connected op";
let arguments = (ins
@ -2912,7 +2914,7 @@ def TFL_CastOp : TFL_Op<"cast", [
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
NoSideEffect, TFL_OperandHasRank<1, 2>]> {
NoSideEffect, TFL_OperandHasRank<1, 2>, TFL_GpuTargetOp]> {
let summary = "MirrorPad Operator. Pads a tensor with mirrored values.";
let description = [{

View File

@ -79,27 +79,12 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_config.quant_specs.serialized_quant_stats));
}
// Note:
// We need to fuse composite ops before LowerStaticTensorList pass.
// The tensorflow list is not supported right now by that pass.
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// The conversion pipeline has to follow the following orders:
// 1) Try to convert ophint nodes if present first like ophint lstm.
// 2) Saved model related optimization like decompose resource ops
// 3) Convert composite functions like lstm/rnns, along with proper function
// inlining & dce.
// 4) Lower static tensor list pass.
// The ophint extractions happen before lots of other passes:
// The assumption of ophint-extraction is each ophinted region is a black-box
@ -122,6 +107,28 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFDevice::CreateDecomposeResourceOpsPass());
// Note:
// We need to fuse composite ops before LowerStaticTensorList pass.
// The tensorflow list is not supported right now by that pass.
// Enable fusing composite ops that can be lowered to built-in TFLite ops.
if (pass_config.emit_builtin_tflite_ops) {
pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
}
// This pass marks non-exported functions as symbol visibility 'private'
// those deemed read-only as immutable.
pass_manager->addPass(
mlir::tf_saved_model::
CreateMarkFunctionVisibilityUsingSavedModelLinkagePass());
pass_manager->addPass(mlir::createInlinerPass());
pass_manager->addPass(mlir::createSymbolDCEPass());
if (pass_config.lower_tensor_list_ops) {
// TODO(haoliang): Add this pass by default.
pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass());
}
// This pass does resource analysis of saved model global tensors and marks
// those deemed read-only as immutable.
pass_manager->addPass(

View File

@ -197,7 +197,9 @@ Status MlirV1CompatGraphOptimizationPass::Run(
RegisterDialects();
mlir::MLIRContext context;
GraphImportConfig import_config;
import_config.upgrade_legacy = true;
// TODO(b/150959075): Running functionalization before TPU cluster formation
// is not semantics preserving and should be disabled for now.
import_config.upgrade_legacy = false;
TF_ASSIGN_OR_RETURN(
auto module_ref,
ConvertGraphToMlir(**options.graph, debug_info, *options.flib_def,

View File

@ -1143,6 +1143,29 @@ that are not a number (NaN) or infinity (Inf). Otherwise, passes `tensor` as-is.
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ClipByValueOp : TF_Op<"ClipByValue", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Clips tensor values to a specified min and max.";
let description = [{
Given a tensor `t`, this operation returns a tensor of the same type and
shape as `t` with its values clipped to `clip_value_min` and `clip_value_max`.
Any values less than `clip_value_min` are set to `clip_value_min`. Any values
greater than `clip_value_max` are set to `clip_value_max`.
}];
let arguments = (ins
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$t,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_min,
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$clip_value_max
);
let results = (outs
TensorOf<[BF16, F16, F32, F64, I16, I32, I64, I8, TF_Complex128, TF_Complex64, TF_Qint32, TF_Qint8, TF_Quint8, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$output
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_ComplexOp : TF_Op<"Complex", [NoSideEffect, ResultsBroadcastableShape]> {
let summary = "Converts two real numbers to a complex number.";
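Aside on the TF_ClipByValueOp definition added above: its description amounts to element-wise clamping. A minimal scalar sketch of that semantics (hypothetical helper, not part of the commit):

#include <algorithm>

// Sketch only: output = min(max(t, clip_value_min), clip_value_max),
// e.g. ClipByValueScalar(9.0f, 2.0f, 8.0f) == 8.0f.
float ClipByValueScalar(float t, float clip_value_min, float clip_value_max) {
  return std::min(std::max(t, clip_value_min), clip_value_max);
}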

File diff suppressed because it is too large.

View File

@ -187,3 +187,199 @@ func @main() {
%write3 = "tf.TensorArrayWriteV3"(%grad3#0, %index, %value, %grad3#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return
}
// -----
// Tests while loop with access to the tensor array defined outside and its
// gradient defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: "tf.While"(%[[VAR]], %[[SIZE]], %[[GVAR]])
%1:2 = "tf.While"(%ta#0, %size) {
body = @while_body, cond = @while_cond, device = "", is_stateless = false}
: (tensor<!tf.resource>, tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>)
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1#0, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @while_body(%[[BARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[BARG1:.*]]: tensor<i32>, %[[BARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_body(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> (tensor<!tf.resource>, tensor<i32>) {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[BARG1]], %[[CONST1]])
%sub = "tf.Sub"(%arg1, %const1) : (tensor<i32>, tensor<i32>) -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[BARG0]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[BARG0]], %[[UPDATE1]])
%write = "tf.TensorArrayWriteV3"(%arg0, %sub, %elem, %flow) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %write) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[BARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[BARG2]], %[[UPDATE2]])
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %sub, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[BARG0]], %[[SUB]], %[[BARG2]]
return %arg0, %sub : tensor<!tf.resource>, tensor<i32>
}
// CHECK: func @while_cond(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<i32>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @while_cond(%arg0: tensor<!tf.resource>, %arg1: tensor<i32>) -> tensor<i32> {
// CHECK-NEXT: return %[[CARG1]]
return %arg1 : tensor<i32>
}
// -----
// Tests If op with access to the tensor array defined outside and its gradient
// defined inside. The gradient creation should be moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.If"(%[[COND]], %[[VAR]], %[[GVAR1]], %[[GVAR2]])
%1 = "tf.If"(%cond, %ta#0) {
then_branch = @then_branch, else_branch = @else_branch, device = "", is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%1, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @then_branch(%[[TARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[TARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @then_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[TARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[TARG1]], %[[UPDATE1]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[TARG0]]
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @else_branch(%[[EARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[EARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @else_branch(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[EARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[EARG2]], %[[UPDATE2]])
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
// CHECK: return %[[EARG0]]
return %arg0 : tensor<!tf.resource>
}
// -----
// Tests (Stateful)PartitionedCall op with access to the tensor array defined
// outside and its gradient defined inside. The gradient creation should be
// moved outside.
// CHECK-LABEL: func @main
func @main() -> () {
// CHECK: %[[SIZE:.*]] = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// CHECK: %[[VAR:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[COND:.*]] = "tf._SomeOp"() : () -> tensor<i1>
%cond = "tf._SomeOp"() : () -> tensor<i1>
// CHECK: %[[GVAR1:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
%grad:2 = "tf.TensorArrayGradV3"(%ta#0, %ta#1) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
// CHECK: %[[GVAR2:.*]] = "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5x3xf32>>>
// CHECK: "tf.StatefulPartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call = "tf.StatefulPartitionedCall"(%ta#0) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: "tf.PartitionedCall"(%[[VAR]], %[[GVAR1]], %[[GVAR2]])
// CHECK-SAME: f = @callee_tensorarray_decomposed
%call2 = "tf.PartitionedCall"(%call) {f = @callee, config = "", config_proto = "", executor_type = ""}
: (tensor<!tf.resource>) -> tensor<!tf.resource>
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[VAR]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: "tf.Slice"(%[[READ]],
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK-LABEL: func @callee
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
%grad:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "a"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite = "tf.TensorArrayWriteV3"(%grad#0, %const1, %elem, %grad#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
%grad2:2 = "tf.TensorArrayGradV3"(%arg0, %flow) {source = "b"} : (tensor<!tf.resource>, tensor<f32>) -> (tensor<!tf.resource>, tensor<f32>)
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
// CHECK: %[[READ2:.*]] = "tf.ReadVariableOp"(%[[CARG2]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE2:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ2]],
// CHECK: "tf.AssignVariableOp"(%[[CARG2]], %[[UPDATE2]])
// CHECK: return %[[CARG0]]
// -----
// Test the pass reports failure on unknown size.
func @main(%arg0: tensor<i32>) -> () {
// expected-error @+1 {{unknown max element count}}
%ta:2 = "tf.TensorArrayV3"(%arg0) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Test the pass reports failure on unknown shape.
func @main(%arg0: tensor<i32>) -> () {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown element shape}}
%ta:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$unknown_rank: true", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
return
}
// -----
// Tests that the pass reports error on ambiguous tensor array.
func @main(%arg0: tensor<i1>) -> () {
%size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
%ta0:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%ta1:2 = "tf.TensorArrayV3"(%size) {dtype = f32, element_shape = "tfshape$dim { size: 3 }", dynamic_size = false, clear_after_read = true, identical_element_shapes = true, tensor_array_name = "ta"} : (tensor<i32>) -> (tensor<!tf.resource>, tensor<f32>)
%if_op = "tf.If"(%arg0, %ta0#0, %ta1#0) {then_branch = @if_then, else_branch = @if_else, is_stateless = false}
: (tensor<i1>, tensor<!tf.resource>, tensor<!tf.resource>) -> tensor<!tf.resource>
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
// expected-error @+1 {{unknown tensor array}}
%read = "tf.TensorArrayReadV3"(%if_op, %index, %ta0#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
func @if_then(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg0 : tensor<!tf.resource>
}
func @if_else(%arg0: tensor<!tf.resource>, %arg1: tensor<!tf.resource>) -> tensor<!tf.resource> {
return %arg1 : tensor<!tf.resource>
}

View File

@@ -40,6 +40,20 @@ class LegalizeHloToTf : public FunctionPass<LegalizeHloToTf> {
void runOnFunction() override;
};
// Returns whether the two values are guaranteed to be broadcastable to the
// same shape; size-1 tensors broadcast up to any rank.
// TODO(jpienaar): Move this to more general location.
static bool AreBroadcastCompatible(Value x, Value y) {
auto x_ranked = x.getType().dyn_cast<RankedTensorType>();
auto y_ranked = y.getType().dyn_cast<RankedTensorType>();
if (!x_ranked || !y_ranked) {
return true;
}
SmallVector<int64_t, 4> resultShape;
return OpTrait::util::getBroadcastedShape(x_ranked.getShape(),
y_ranked.getShape(), resultShape);
}
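For instance, tensor<1xf32> and tensor<4x4xf32> are reported as compatible (the size-1 dimension broadcasts up), tensor<3xf32> and tensor<4xf32> are not, and an unranked operand makes the check return true conservatively.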
#include "tensorflow/compiler/mlir/tensorflow/transforms/generated_legalize_hlo.inc"
/// Performs the lowering to XLA dialect.

View File

@@ -20,14 +20,16 @@ include "mlir/Dialect/StandardOps/IR/Ops.td"
include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td"
include "tensorflow/compiler/mlir/xla/ir/hlo_ops.td"
def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>;
def : Pat<(HLO_ConstOp $value), (TF_ConstOp $value)>;
//===----------------------------------------------------------------------===//
// Binary op patterns.
//===----------------------------------------------------------------------===//
class DirectBinaryPat<Op FromOp, Op ToOp>
: Pat<(FromOp $l, $r, $_), (ToOp $l, $r)>;
// Check that two values can be broadcasted together
// TODO(jpienaar): Move somewhere more general
def AreBroadcastCompatible : Constraint<CPred<"AreBroadcastCompatible($0, $1)">,
"types must be broadcastable">;
foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_DivOp, TF_DivOp],
@@ -37,24 +39,41 @@ foreach fromToBinPair = [[HLO_AddOp, TF_AddV2Op],
[HLO_MulOp, TF_MulOp],
[HLO_PowOp, TF_PowOp],
[HLO_DivOp, TF_RealDivOp],
[HLO_SubOp, TF_SubOp]] in
def : DirectBinaryPat<fromToBinPair[0], fromToBinPair[1]>;
[HLO_SubOp, TF_SubOp],
[HLO_Atan2Op, TF_Atan2Op],
[HLO_RemOp, TF_ModOp]] in
def : Pat<(fromToBinPair[0] $l, $r, $_), (fromToBinPair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
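Read concretely, the [HLO_AddOp, TF_AddV2Op] entry above rewrites an HLO add into tf.AddV2 on the same operands, dropping the broadcast-dimensions argument matched by $_, and the rewrite only fires when the AreBroadcastCompatible constraint holds for $l and $r; the other pairs in the list expand analogously.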
def LowerRightShiftSigned :
Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(SignedIntTensor $r)]>;
foreach pair = [[HLO_AndOp, TF_BitwiseAndOp],
[HLO_OrOp, TF_BitwiseOrOp],
[HLO_XorOp, TF_BitwiseXorOp]] in
def : Pat<(pair[0] TF_IntTensor:$l, TF_IntTensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r)>;
foreach pair = [[HLO_AndOp, TF_LogicalAndOp],
[HLO_OrOp, TF_LogicalOrOp]] in
def : Pat<(pair[0] I1Tensor:$l, I1Tensor:$r, $_), (pair[1] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightArithmeticOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_ShiftRightLogicalOp $l, $r, $_), (TF_RightShiftOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
def : Pat<(HLO_FloorOp (HLO_DivOp $l, $r, $_)), (TF_FloorDivOp $l, $r),
[(AreBroadcastCompatible $l, $r)]>;
//===----------------------------------------------------------------------===//
// Unary op patterns.
//===----------------------------------------------------------------------===//
foreach Mapping = [
[HLO_AbsOp, TF_AbsOp],
foreach Mapping = [[HLO_AbsOp, TF_AbsOp],
[HLO_BitcastConvertOp, TF_BitcastOp],
[HLO_CeilOp, TF_CeilOp],
[HLO_CosOp, TF_CosOp],
[HLO_ExpOp, TF_ExpOp],
[HLO_Expm1Op, TF_Expm1Op],
[HLO_FloorOp, TF_FloorOp],
[HLO_ImagOp, TF_ImagOp],
[HLO_IsFiniteOp, TF_IsFiniteOp],
@@ -65,8 +84,46 @@ foreach Mapping = [
[HLO_RealOp, TF_RealOp],
[HLO_RsqrtOp, TF_RsqrtOp],
[HLO_SinOp, TF_SinOp],
[HLO_SignOp, TF_SignOp],
[HLO_SqrtOp, TF_SqrtOp],
[HLO_TanhOp, TF_TanhOp],
] in {
def : Pat<(Mapping[0] $input), (Mapping[1] $input)>;
}
[HLO_TanhOp, TF_TanhOp]] in
def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>;
def : Pat<(HLO_AbsOp TF_ComplexTensor:$arg), (TF_ComplexAbsOp $arg)>;
def : Pat<(HLO_BroadcastOp $arg, $shape),
(TF_BroadcastToOp $arg, (TF_ConstOp $shape))>;
def : Pat<(HLO_TransposeOp $arg, $permutation),
(TF_TransposeOp $arg, (TF_ConstOp $permutation))>;
def : Pat<(HLO_ReverseOp $op, $dims), (TF_ReverseV2Op $op, (TF_ConstOp $dims))>;
//===----------------------------------------------------------------------===//
// Ternary op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(HLO_ClampOp $min, $arg, $max),
(TF_MaximumOp (TF_MinimumOp $arg, $max), $min)>;
def : Pat<(HLO_SelectOp $cond, $t, $e), (TF_SelectOp $cond, $t, $e)>;
//===----------------------------------------------------------------------===//
// Variadic op patterns.
//===----------------------------------------------------------------------===//
def : Pat<(HLO_ConcatenateOp $inputs, $dim),
(TF_ConcatV2Op $inputs, (TF_ConstOp $dim))>;
//===----------------------------------------------------------------------===//
// Compare op patterns.
//===----------------------------------------------------------------------===//
foreach p = [[TF_EqualOp, HLO_COMPARISON_DIRECTION_EQ],
[TF_NotEqualOp, HLO_COMPARISON_DIRECTION_NE]] in
def : Pat<(HLO_CompareOp $l, $r, $_, p[1]), (p[0] $l, $r, ConstBoolAttrTrue),
[(AreBroadcastCompatible $l, $r)]>;
foreach pair = [[TF_GreaterEqualOp, HLO_COMPARISON_DIRECTION_GE],
[TF_GreaterOp, HLO_COMPARISON_DIRECTION_GT],
[TF_LessEqualOp, HLO_COMPARISON_DIRECTION_LE],
[TF_LessOp, HLO_COMPARISON_DIRECTION_LT]] in
def : Pat<(HLO_CompareOp $l, $r, $_, pair[1]), (pair[0] $l, $r),
[(AreBroadcastCompatible $l, $r)]>;

View File

@@ -19,7 +19,8 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
@@ -55,6 +56,8 @@ namespace {
namespace cutil = TF::collection_ops_util;
using std::string;
// A pass that converts tensor array operations to tensor operations and
// read/assign ops on local variables. A later resource lifting pass can further
// remove the local variables.
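For example (as exercised by the tests earlier in this change), a tf.TensorArrayV3 of size 5 with element shape 3xf32 becomes a local variable holding a 5x3xf32 buffer; TensorArrayReadV3 lowers to a ReadVariableOp followed by tf.Slice, and TensorArrayWriteV3 to a read, tf.XlaDynamicUpdateSlice, and tf.AssignVariableOp.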
@@ -85,7 +88,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split,
return split.emitOpError("unknown or invalid split tensor shape");
}
int64_t length = buffer_type.getDimSize(0) / *count;
for (auto len : lengths_const.value().getValues<APInt>()) {
for (const auto& len : lengths_const.value().getValues<APInt>()) {
if (length == len.getSExtValue()) continue;
return split.emitOpError("different split lengths are not supported");
}
@@ -145,7 +148,7 @@ struct TensorArrayStats {
// this is a gradient.
bool accumulate_on_write;
// Maps from a gradient source string to the local variable holding the gradient.
llvm::SmallDenseMap<llvm::StringRef, Value> grads;
llvm::StringMap<Value> grads;
};
LogicalResult HandleTensorArrayV3Op(
@@ -224,10 +227,7 @@ LogicalResult HandleTensorArrayWriteV3Op(
cutil::GetElement(index_reshape, buffer, builder, write.getLoc(),
/*keep_slice_shape=*/true);
// Add a size-1 leading dimension to elem.
for (auto dim : buffer.getType().cast<RankedTensorType>().getShape())
LOG(ERROR) << " buffer : " << dim;
auto slice_type = original_elem.getType().cast<RankedTensorType>();
for (auto dim : slice_type.getShape()) LOG(ERROR) << " resahpe : " << dim;
elem = builder.create<TF::ReshapeOp>(
write.getLoc(), ArrayRef<Type>{slice_type},
ArrayRef<Value>{elem, cutil::GetR1Const(slice_type.getShape(), builder,
@@ -339,6 +339,26 @@ LogicalResult HandleTensorArraySizeV3Op(
return success();
}
LogicalResult CreateAndInitializeGradVariable(Type local_var_type,
Operation* op, Value* var) {
OpBuilder builder(op);
*var = builder.create<TF::MlirLocalVarOp>(
op->getLoc(), ArrayRef<Type>{local_var_type}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
Value buffer;
auto buffer_type = getElementTypeOrSelf(local_var_type)
.cast<TF::ResourceType>()
.getSubtypes()[0]
.cast<RankedTensorType>();
if (failed(cutil::CreateInitBufferValue(
buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op,
buffer_type.getElementType(), builder, &buffer))) {
return failure();
}
cutil::WriteLocalVariable(*var, buffer, builder, op->getLoc());
return success();
}
LogicalResult HandleTensorArrayGradV3Op(
TF::TensorArrayGradV3Op grad,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
@@ -347,26 +367,17 @@ LogicalResult HandleTensorArrayGradV3Op(
Value grad_var;
auto sit = stats->find(local_var);
if (sit == stats->end()) return grad.emitOpError("unknown tensor array");
auto emplace_res = sit->getSecond().grads.try_emplace(grad.source(), Value());
auto emplace_res =
sit->getSecond().grads.try_emplace(grad.source().str(), Value());
if (!emplace_res.second) {
// If the source has been assigned a grad, use it.
grad_var = emplace_res.first->getSecond();
grad_var = emplace_res.first->second;
} else {
grad_var = builder.create<TF::MlirLocalVarOp>(
grad.getLoc(), ArrayRef<Type>{local_var.getType()}, ArrayRef<Value>{},
ArrayRef<NamedAttribute>{});
Value buffer;
auto buffer_type = getElementTypeOrSelf(local_var.getType())
.cast<TF::ResourceType>()
.getSubtypes()[0]
.cast<RankedTensorType>();
if (failed(cutil::CreateInitBufferValue(
buffer_type.getShape().drop_front(), buffer_type.getDimSize(0),
grad, buffer_type.getElementType(), builder, &buffer))) {
if (failed(CreateAndInitializeGradVariable(local_var.getType(), grad,
&grad_var))) {
return failure();
}
cutil::WriteLocalVariable(grad_var, buffer, builder, grad.getLoc());
emplace_res.first->getSecond() = grad_var;
emplace_res.first->second = grad_var;
// Write to a grad accumulates with previous writes.
(*stats)[grad_var].accumulate_on_write = true;
}
@@ -409,36 +420,454 @@ LogicalResult HandleTensorArrayScatterV3Op(
return success();
}
LogicalResult DecomposeTensorArrayOps(Block* block, ModuleOp module) {
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
// Updates func's type according to its current arguments and return values.
void UpdateFuncType(FuncOp func) {
llvm::SmallVector<Type, 8> arg_types;
for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
func.setType(FunctionType::get(
arg_types,
llvm::to_vector<8>(func.front().getTerminator()->getOperandTypes()),
func.getContext()));
}
// Finds the accessed gradient sources for each tensor array argument.
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> AccessedGradients(
ArrayRef<FuncOp> funcs, ModuleOp module) {
llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>> result;
llvm::SmallDenseMap<int64_t, llvm::StringSet<>> result_sets;
auto insert = [&](Value v, const string& source) {
auto arg = v.dyn_cast<BlockArgument>();
if (!arg) return;
auto insert_res = result_sets[arg.getArgNumber()].insert(source);
if (!insert_res.second) return;
result[arg.getArgNumber()].push_back(source);
};
for (FuncOp func : funcs) {
for (auto& op : func.front().getOperations()) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
continue;
}
if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
insert(grad.handle(), grad.source().str());
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
for (const auto& entry : AccessedGradients({body, cond}, module)) {
for (const string& source : entry.getSecond()) {
insert(while_op.getOperand(entry.getFirst()), source);
}
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
for (const auto& entry :
AccessedGradients({then_branch, else_branch}, module)) {
for (const string& source : entry.getSecond()) {
insert(if_op.getOperand(entry.getFirst() + 1), source);
}
}
} else if (auto pc = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!pc.f().isa<FlatSymbolRefAttr>()) continue;
auto callee = module.lookupSymbol<FuncOp>(pc.f().getRootReference());
for (const auto& entry : AccessedGradients({callee}, module)) {
for (const string& source : entry.getSecond()) {
insert(pc.getOperand(entry.getFirst()), source);
}
}
} else if (auto spc =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
auto callee = module.lookupSymbol<FuncOp>(spc.f());
for (const auto& entry : AccessedGradients({callee}, module)) {
for (const string& source : entry.getSecond()) {
insert(spc.getOperand(entry.getFirst()), source);
}
}
}
}
}
return result;
}
// Contains cached information for decomposed callee functions for (stateful)
// partitioned call ops.
struct PartitionedCallTensorArrayOpsInfo {
bool signature_change;
FuncOp decomposed_callee;
llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<string, 4>>, 4>
arg_grads;
llvm::SmallVector<std::pair<int64_t, int64_t>, 4> ret_forward_input;
};
// Updates a called function's input signature by adjusting resource types, and
// adding required gradient arguments.
void ChangeFunctionInputSignature(
FuncOp func,
const llvm::SmallDenseMap<int64_t, llvm::SmallVector<string, 4>>& grads,
llvm::function_ref<Type(int64_t)> ta_arg_buffer_type,
llvm::function_ref<bool(int64_t)> ta_accumulate_on_write,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats) {
int64_t original_args = func.getNumArguments();
for (int64_t argnum = 0; argnum < original_args; ++argnum) {
auto arg = func.getArgument(argnum);
Type t = ta_arg_buffer_type(argnum);
if (!t) continue;
arg.setType(t);
auto grad_it = grads.find(argnum);
if (grad_it == grads.end()) continue;
llvm::StringMap<Value> grads_map;
for (const string& source : grad_it->getSecond()) {
auto g = func.front().addArgument(t);
(*stats)[g].accumulate_on_write = true;
grads_map[source] = g;
}
auto& stat = (*stats)[arg];
stat.accumulate_on_write = ta_accumulate_on_write(argnum);
stat.grads = std::move(grads_map);
}
UpdateFuncType(func);
}
LogicalResult DecomposeTensorArrayOps(
Block*, ModuleOp, llvm::SmallDenseMap<Value, TensorArrayStats>*,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*);
LogicalResult HandleWhileOp(
TF::WhileOp while_op, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto body = module.lookupSymbol<FuncOp>(while_op.body());
auto cond = module.lookupSymbol<FuncOp>(while_op.cond());
auto grads = AccessedGradients({body, cond}, module);
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(while_op.getOperand(index));
if (it == stats->end()) return nullptr;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(while_op.getOperand(index));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
llvm::SmallDenseMap<Value, TensorArrayStats> body_stats;
ChangeFunctionInputSignature(body, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &body_stats);
llvm::SmallDenseMap<Value, TensorArrayStats> cond_stats;
ChangeFunctionInputSignature(cond, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &cond_stats);
if (failed(DecomposeTensorArrayOps(&body.front(), module, &body_stats,
decomposed_partitioned_call_callees)) ||
failed(DecomposeTensorArrayOps(&cond.front(), module, &cond_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
if (body_stats.empty() && cond_stats.empty()) return success();
auto old_body_ret = body.front().getTerminator();
auto new_retvals = llvm::to_vector<8>(old_body_ret->getOperands());
for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
if (!ta_arg_buffer_type(i)) continue;
auto retval = old_body_ret->getOperand(i);
auto arg = retval.dyn_cast<BlockArgument>();
if (!arg) {
return while_op.emitOpError(
"output tensor array does not alias input in a while loop");
}
for (const string& source : grads[i]) {
new_retvals.push_back(body_stats[arg].grads[source]);
}
}
OpBuilder(old_body_ret).create<ReturnOp>(old_body_ret->getLoc(), new_retvals);
old_body_ret->erase();
UpdateFuncType(body);
// Recreate the while op.
auto operands = llvm::to_vector<8>(while_op.getOperands());
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
auto grad_it = grads.find(i);
auto& stat = (*stats)[operands[i]];
if (grad_it == grads.end()) continue;
for (const string& source : grad_it->getSecond()) {
auto it = stat.grads.find(source);
if (it != stat.grads.end()) {
operands.push_back(it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(operands[i].getType(),
while_op, &grad_var))) {
return failure();
}
stat.grads[source] = grad_var;
operands.push_back(grad_var);
}
}
}
OpBuilder builder(while_op);
auto new_while =
builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(),
operands, while_op.getAttrs());
// Clear the output shapes as they are not needed for XLA lowering.
new_while.setAttr("output_shapes", builder.getArrayAttr({}));
for (int64_t i = 0; i < while_op.getNumOperands(); ++i) {
if (ta_arg_buffer_type(i)) {
while_op.getResult(i).replaceAllUsesWith(while_op.getOperand(i));
} else {
while_op.getResult(i).replaceAllUsesWith(new_while.getResult(i));
}
}
while_op.erase();
return success();
}
LogicalResult HandleIfOp(
TF::IfOp if_op, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto then_branch = module.lookupSymbol<FuncOp>(if_op.then_branch());
auto else_branch = module.lookupSymbol<FuncOp>(if_op.else_branch());
auto grads = AccessedGradients({then_branch, else_branch}, module);
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(if_op.getOperand(index + 1));
if (it == stats->end()) return nullptr;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(if_op.getOperand(index + 1));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
llvm::SmallDenseMap<Value, TensorArrayStats> then_stats;
ChangeFunctionInputSignature(then_branch, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &then_stats);
llvm::SmallDenseMap<Value, TensorArrayStats> else_stats;
ChangeFunctionInputSignature(else_branch, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &else_stats);
if (failed(DecomposeTensorArrayOps(&then_branch.front(), module, &then_stats,
decomposed_partitioned_call_callees)) ||
failed(DecomposeTensorArrayOps(&else_branch.front(), module, &else_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
if (then_stats.empty() && else_stats.empty()) return success();
// Recreate the if op.
auto operands = llvm::to_vector<8>(if_op.getOperands());
for (int64_t i = 0; i < if_op.getNumOperands() - 1; ++i) {
auto grad_it = grads.find(i);
auto& stat = (*stats)[operands[i + 1]];
if (grad_it == grads.end()) continue;
for (const string& source : grad_it->getSecond()) {
auto it = stat.grads.find(source);
if (it != stat.grads.end()) {
operands.push_back(it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(operands[i + 1].getType(),
if_op, &grad_var))) {
return failure();
}
stat.grads[source] = grad_var;
operands.push_back(grad_var);
}
}
}
OpBuilder builder(if_op);
auto new_if = builder.create<TF::IfOp>(if_op.getLoc(),
then_branch.getType().getResults(),
operands, if_op.getAttrs());
// Clear the output shapes as they are not needed for XLA lowering.
new_if.setAttr("output_shapes", builder.getArrayAttr({}));
auto ret_forwards_input = [](FuncOp f, int64_t ret_ind) -> int64_t {
auto retval = f.front().getTerminator()->getOperand(ret_ind);
auto arg = retval.dyn_cast<BlockArgument>();
if (!arg) return -1;
return arg.getArgNumber();
};
for (int64_t i = 0; i < if_op.getNumResults(); ++i) {
if (!getElementTypeOrSelf(if_op.getResult(i).getType())
.isa<TF::ResourceType>()) {
if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i));
continue;
}
int64_t then_forward_input = ret_forwards_input(then_branch, i);
int64_t else_forward_input = ret_forwards_input(else_branch, i);
if (then_forward_input != else_forward_input || then_forward_input < 0) {
return if_op.emitOpError(
"branches do not forward the same input resource");
}
if_op.getResult(i).replaceAllUsesWith(
if_op.getOperand(then_forward_input + 1));
}
if_op.erase();
return success();
}
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
CallOp call, FuncOp callee, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
callee, PartitionedCallTensorArrayOpsInfo());
auto& info = emplace_res.first->getSecond();
// Recreates the call op with info.
auto recreate_caller = [&]() -> LogicalResult {
auto new_operands = llvm::to_vector<8>(call.getOperands());
for (const auto& entry : info.arg_grads) {
auto it = stats->find(call.getOperand(entry.first));
if (it == stats->end()) return call.emitOpError("unknown tensor array");
for (const string& source : entry.second) {
auto grad_it = it->getSecond().grads.find(source);
if (grad_it != it->getSecond().grads.end()) {
new_operands.push_back(grad_it->second);
} else {
Value grad_var;
if (failed(CreateAndInitializeGradVariable(it->getFirst().getType(),
call, &grad_var))) {
return failure();
}
it->getSecond().grads[source] = grad_var;
new_operands.push_back(grad_var);
}
}
}
OpBuilder builder(call);
auto new_call = builder.create<CallOp>(
call.getLoc(), info.decomposed_callee.getType().getResults(),
new_operands, call.getAttrs());
new_call.setAttr(
"f", builder.getSymbolRefAttr(
const_cast<FuncOp&>(info.decomposed_callee).getName()));
for (const auto& entry : info.ret_forward_input) {
call.getResult(entry.first)
.replaceAllUsesWith(call.getOperand(entry.second));
}
call.replaceAllUsesWith(new_call);
call.erase();
return success();
};
if (!emplace_res.second) {
// This callee was handled before.
if (!info.signature_change) return success();
return recreate_caller();
}
// Rewrite the callee on a cloned function.
info.signature_change = false;
auto ta_arg_buffer_type = [&](int64_t index) -> Type {
auto it = stats->find(call.getOperand(index));
if (it == stats->end()) return nullptr;
info.signature_change = true;
return it->getFirst().getType();
};
auto ta_accumulate_on_write = [&](int64_t index) {
auto it = stats->find(call.getOperand(index));
if (it == stats->end()) return false;
return it->getSecond().accumulate_on_write;
};
auto callee_clone = callee.clone();
auto grads = AccessedGradients({callee_clone}, module);
for (int64_t i = 0; i < callee_clone.getNumArguments(); ++i) {
auto it = grads.find(i);
if (it == grads.end()) continue;
info.arg_grads.emplace_back(i, it->getSecond());
}
llvm::SmallDenseMap<Value, TensorArrayStats> callee_stats;
ChangeFunctionInputSignature(callee_clone, grads, ta_arg_buffer_type,
ta_accumulate_on_write, &callee_stats);
if (failed(DecomposeTensorArrayOps(&callee_clone.front(), module,
&callee_stats,
decomposed_partitioned_call_callees))) {
return failure();
}
for (int64_t i = 0; i < call.getNumResults(); ++i) {
auto ret = callee_clone.front().getTerminator()->getOperand(i);
if (!getElementTypeOrSelf(ret.getType()).isa<TF::ResourceType>()) continue;
auto arg = ret.dyn_cast<BlockArgument>();
if (!arg) continue;
info.ret_forward_input.emplace_back(i, arg.getArgNumber());
}
if (!info.signature_change) {
// Signature is not modified. We do not need to keep two copies.
info.signature_change = false;
auto name = callee.getName();
callee.erase();
callee_clone.setName(name);
SymbolTable(module).insert(callee_clone);
} else {
info.decomposed_callee = callee_clone;
// Add the clone with a new name.
auto name =
llvm::formatv("{0}_{1}", callee.getName(), "tensorarray_decomposed")
.str();
callee_clone.setName(name);
SymbolTable(module).insert(callee_clone);
}
if (info.signature_change) return recreate_caller();
return success();
}
LogicalResult DecomposeTensorArrayOps(
Block* block, ModuleOp module,
llvm::SmallDenseMap<Value, TensorArrayStats>* stats,
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>*
decomposed_partitioned_call_callees) {
for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
if (llvm::isa<TF::IdentityOp>(&op) || llvm::isa<TF::IdentityNOp>(&op)) {
op.replaceAllUsesWith(op.getOperands());
op.erase();
} else if (auto ta = llvm::dyn_cast<TF::TensorArrayV3Op>(&op)) {
if (failed(HandleTensorArrayV3Op(ta, module, &stats))) {
if (failed(HandleTensorArrayV3Op(ta, module, stats))) {
return failure();
}
} else if (auto read = llvm::dyn_cast<TF::TensorArrayReadV3Op>(&op)) {
if (failed(HandleTensorArrayReadV3Op(read, stats))) return failure();
if (failed(HandleTensorArrayReadV3Op(read, *stats))) return failure();
} else if (auto write = llvm::dyn_cast<TF::TensorArrayWriteV3Op>(&op)) {
if (failed(HandleTensorArrayWriteV3Op(write, stats))) return failure();
if (failed(HandleTensorArrayWriteV3Op(write, *stats))) return failure();
} else if (auto concat = llvm::dyn_cast<TF::TensorArrayConcatV3Op>(&op)) {
if (failed(HandleTensorArrayConcatV3Op(concat, stats))) return failure();
if (failed(HandleTensorArrayConcatV3Op(concat, *stats))) return failure();
} else if (auto split = llvm::dyn_cast<TF::TensorArraySplitV3Op>(&op)) {
if (failed(HandleTensorArraySplitV3Op(split, stats))) return failure();
if (failed(HandleTensorArraySplitV3Op(split, *stats))) return failure();
} else if (auto size = llvm::dyn_cast<TF::TensorArraySizeV3Op>(&op)) {
if (failed(HandleTensorArraySizeV3Op(size, stats))) return failure();
if (failed(HandleTensorArraySizeV3Op(size, *stats))) return failure();
} else if (auto grad = llvm::dyn_cast<TF::TensorArrayGradV3Op>(&op)) {
if (failed(HandleTensorArrayGradV3Op(grad, &stats))) return failure();
if (failed(HandleTensorArrayGradV3Op(grad, stats))) return failure();
} else if (auto gather = llvm::dyn_cast<TF::TensorArrayGatherV3Op>(&op)) {
if (failed(HandleTensorArrayGatherV3Op(gather, stats))) return failure();
if (failed(HandleTensorArrayGatherV3Op(gather, *stats))) return failure();
} else if (auto scatter = llvm::dyn_cast<TF::TensorArrayScatterV3Op>(&op)) {
if (failed(HandleTensorArrayScatterV3Op(scatter, stats))) {
if (failed(HandleTensorArrayScatterV3Op(scatter, *stats))) {
return failure();
}
} else if (auto close = llvm::dyn_cast<TF::TensorArrayCloseV3Op>(&op)) {
close.erase();
} else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
if (failed(HandleWhileOp(while_op, module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
if (failed(HandleIfOp(if_op, module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
if (!pcall.f().isa<FlatSymbolRefAttr>()) {
return pcall.emitOpError(
"TensorArray decomposition does not support call with nested "
"references.");
}
if (failed(HandlePartitionedCallOp(
pcall, module.lookupSymbol<FuncOp>(pcall.f().getRootReference()),
module, stats, decomposed_partitioned_call_callees))) {
return failure();
}
} else if (auto spcall =
llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
if (failed(HandlePartitionedCallOp(
spcall, module.lookupSymbol<FuncOp>(spcall.f()), module, stats,
decomposed_partitioned_call_callees))) {
return failure();
}
}
}
return success();
@@ -448,7 +877,11 @@ void TensorArrayOpsDecompositionPass::runOnModule() {
auto module = getModule();
auto main = module.lookupSymbol<FuncOp>("main");
if (!main) return;
if (failed(DecomposeTensorArrayOps(&main.front(), module))) {
llvm::SmallDenseMap<Value, TensorArrayStats> stats;
llvm::SmallDenseMap<FuncOp, PartitionedCallTensorArrayOpsInfo>
decomposed_partitioned_call_callees;
if (failed(DecomposeTensorArrayOps(&main.front(), module, &stats,
&decomposed_partitioned_call_callees))) {
signalPassFailure();
}
}

View File

@@ -261,6 +261,7 @@ Status ConvertMLIRToXlaComputation(
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
tf2xla.addPass(mlir::TF::CreateTensorListOpsDecompositionPass());
tf2xla.addPass(mlir::TF::CreateStackOpsDecompositionPass());
tf2xla.addPass(mlir::TF::CreateTensorArrayOpsDecompositionPass());
tf2xla.addPass(mlir::TFDevice::CreateDecomposeResourceOpsPass());
tf2xla.addPass(mlir::TF::CreatePromoteResourcesToArgsPass());
// LegalizeTFControlFlow encapsulates arguments for control flow operations

View File

@@ -732,6 +732,7 @@ genrule(
outs = ["operator_writers.inc"],
cmd = ("$(location :operator_writer_gen) " +
"-I external/llvm-project/mlir/include " +
"-I external/org_tensorflow " +
"$(location //tensorflow/compiler/mlir/xla:ir/hlo_ops.td) " +
" -o $@"),
tools = [":operator_writer_gen"],

View File

@@ -667,10 +667,11 @@ static Device* LookupDevice(const PyLocalClient& client, int device_id) {
PyLocalExecutable::PyLocalExecutable(
std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment, PyLocalClient* client)
bool tuple_arguments, DeviceAssignment device_assignment,
PyLocalClient* client)
: client_(client),
device_assignment_(
std::make_shared<DeviceAssignment>(device_assignment)) {
device_assignment_(std::make_shared<DeviceAssignment>(device_assignment)),
tuple_arguments_(tuple_arguments) {
executables_.reserve(executables.size());
for (auto& executable : executables) {
executables_.emplace_back(std::move(executable));
@@ -727,7 +728,7 @@ PyLocalExecutable::ExecuteHelper(
std::unique_ptr<PyLocalBuffer> tuple_buffer;
std::vector<PyLocalBuffer*> tupled_arguments;
if (options.tuple_arguments) {
if (options.tuple_arguments || tuple_arguments_) {
TF_ASSIGN_OR_RETURN(tuple_buffer, PyLocalBuffer::MakeTuple(
argument_handles, client_, device));
tupled_arguments = {tuple_buffer.get()};
@@ -1037,7 +1038,8 @@ PyLocalExecutable::Compile(const XlaComputation& computation,
build_options));
return absl::make_unique<PyLocalExecutable>(
std::move(local_executables), build_options.device_assignment(), client);
std::move(local_executables), options.tuple_arguments,
build_options.device_assignment(), client);
}
} // namespace xla

View File

@@ -316,6 +316,10 @@ struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the arguments to the computation will be wrapped in a tuple and
// passed as a single parameter.
bool tuple_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
};
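A minimal, hypothetical C++ sketch of how the new field is meant to be used (not part of this diff; only CompileOptions::tuple_arguments above comes from the header):

  xla::CompileOptions options;
  options.tuple_arguments = true;  // wrap all arguments into one tuple parameter
  // Passing `options` to PyLocalExecutable::Compile(...) stores the flag as
  // tuple_arguments_, so ExecuteHelper() tuples the argument buffers even when
  // the per-execution options leave their tuple_arguments false.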
@@ -340,7 +344,8 @@ class PyLocalExecutable {
CompileOptions options);
PyLocalExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
DeviceAssignment device_assignment, PyLocalClient* client);
bool tuple_arguments, DeviceAssignment device_assignment,
PyLocalClient* client);
PyLocalClient* client() const { return client_; }
@@ -404,6 +409,10 @@ class PyLocalExecutable {
std::vector<std::shared_ptr<LocalExecutable>> executables_;
std::shared_ptr<DeviceAssignment> device_assignment_;
// True if the executables were compiled expecting arguments in a single
// tuple.
const bool tuple_arguments_;
// The replica and partition indices of device_assignment_ to be run by this
// client. On single-host platforms without partitioning, this is all replicas
// (i.e. local_logical_device_ids_[i] = (i, 0)), but this may not be the case

View File

@@ -44,6 +44,12 @@ class CpuTransferManager : public GenericTransferManager {
const Shape& literal_shape,
MutableBorrowingLiteral literal) override;
bool CanShapedBufferBeAccessedNow(
se::StreamExecutor* executor,
const ShapedBuffer& device_buffer) const override {
return true;
}
private:
Status TransferBufferToInfeed(se::StreamExecutor* executor, int64 size,
const void* source);

View File

@@ -518,7 +518,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
<< "Invalid LLVM IR before optimizations:\n"
<< err_stream.str()
<< "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. "
"Rerun with --xla_dump_to to get the IR. ";
"Rerun with --xla_dump_to to get the IR and looks for files with "
"name containing: *"
<< FilenameFor(*module, "", "") << "*";
}
GpuVersion gpu_version = GetGpuVersion(stream_exec);

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>
@@ -66,6 +67,25 @@ bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
return false;
}
std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
std::vector<int64> relative_order;
for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
if (shape.dimensions(dim) > 1) {
relative_order.push_back(dim);
}
}
// Now normalize the dimensions to values between 0 and true rank - 1.
std::vector<int64> sorted_dims = relative_order;
std::sort(sorted_dims.begin(), sorted_dims.end());
for (int64& dim : relative_order) {
int64 sorted_index = std::distance(
sorted_dims.begin(),
std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
dim = sorted_index;
}
return relative_order;
}
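For illustration: given f16[128,1,32,32] with layout {1,3,2,0}, MinorToMajor yields [1,3,2,0]; dropping the size-1 dimension leaves [3,2,0], which normalizes to [2,1,0]. The {3,2,1,0} layout on the same shape also reduces to [3,2,0] and then [2,1,0], which is why the MixedLayoutProducerWithTrivialDim test below considers the two layouts fusion friendly.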
} // namespace
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
@@ -73,17 +93,20 @@ bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
std::vector<HloInstruction*> params;
AppendParams(producer, &params);
AppendParams(reduce, &params);
int64 max_rank = -1;
const Layout* max_rank_layout;
int64 max_true_rank = -1;
std::vector<int64> max_rank_order;
for (HloInstruction* param : params) {
if (param->shape().IsArray() && param->shape().rank() > max_rank) {
max_rank = param->shape().rank();
max_rank_layout = &param->shape().layout();
if (param->shape().IsArray() &&
ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
max_true_rank = ShapeUtil::TrueRank(param->shape());
max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
}
}
return absl::c_all_of(params, [&](HloInstruction* param) {
return (!param->shape().IsArray()) || (param->shape().rank() < max_rank) ||
(LayoutUtil::Equal(param->shape().layout(), *max_rank_layout));
return !param->shape().IsArray() ||
ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
max_rank_order;
});
}

View File

@@ -91,6 +91,44 @@ TEST_F(GpuFusibleTest,
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
}
TEST_F(GpuFusibleTest,
LayoutsAreReduceInputFusionFriendly_MixedLayoutProducerWithTrivialDim) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
mixed_input_layouts_computation {
p0.1 = f16[128,1,32,32]{1,3,2,0} parameter(0)
p1.1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
copy = f16[128,1,32,32]{1,3,2,0} copy(p1.1)
c0 = f16[] constant(0)
broadcast = f16[128,1,32,32]{1,3,2,0} broadcast(c0), dimensions={}
greater-than = pred[128,1,32,32]{1,3,2,0} compare(copy, broadcast), direction=GT
ROOT root = f16[128,1,32,32]{1,3,2,0} select(greater-than, p0.1, broadcast)
}
fused_reduce {
p0.2 = f16[128,1,32,32]{1,3,2,0} parameter(0)
convert = f32[128,1,32,32]{1,3,2,0} convert(p0.2)
c0.2 = f32[] constant(0)
ROOT reduce = f32[1]{0} reduce(convert, c0.2), dimensions={0,2,3}, to_apply=scalar_add
}
ENTRY entry {
p0 = f16[128,1,32,32]{1,3,2,0} parameter(0)
p1 = f16[128,1,32,32]{3,2,1,0} parameter(1)
loop_fusion = f16[128,1,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=mixed_input_layouts_computation
reduce_fusion = f32[1]{0} fusion(loop_fusion), kind=kInput, calls=fused_reduce
ROOT root = (f32[1]{0}, f16[128,1,32,32]{1,3,2,0}) tuple(reduce_fusion, loop_fusion)
})"))
.ValueOrDie();
SCOPED_TRACE(module->ToString());
const HloInstruction* reduce_fusion =
module->entry_computation()->root_instruction()->operand(0);
ASSERT_EQ(reduce_fusion->fused_expression_root()->opcode(),
HloOpcode::kReduce);
const HloInstruction* loop_fusion =
module->entry_computation()->root_instruction()->operand(1);
ASSERT_EQ(loop_fusion->fused_expression_root()->opcode(), HloOpcode::kSelect);
EXPECT_TRUE(
LayoutsAreReduceInputFusionFriendly(*loop_fusion, *reduce_fusion));
}
TEST_F(GpuFusibleTest, LayoutsAreReduceInputFusionFriendly_CopyProducer) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
fused_reduce {
@@ -152,17 +190,18 @@ TEST_F(GpuFusibleTest,
}
TEST_F(GpuFusibleTest,
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumRanksParamsOnly) {
LayoutsAreReduceInputFusionFriendly_ConsiderMaximumTrueRanksParamsOnly) {
auto module = ParseAndReturnVerifiedModule(absl::StrCat(kModulePrefix, R"(
broadcasting_computation {
p0.1 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
p1.1 = f32[128]{0} parameter(1)
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(p1.1), dimensions={0}
p1.1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
reshape = f32[128]{0} reshape(p1.1)
broadcast = f32[128,1024,32,32]{1,3,2,0} broadcast(reshape), dimensions={0}
ROOT add = f32[128,1024,32,32]{1,3,2,0} add(p0.1, broadcast)
}
ENTRY entry {
p0 = f32[128,1024,32,32]{1,3,2,0} parameter(0)
p1 = f32[128]{0} parameter(1)
p1 = f32[1,128,1,1]{3,2,1,0} parameter(1)
loop_fusion = f32[128,1024,32,32]{1,3,2,0} fusion(p0, p1), kind=kLoop, calls=broadcasting_computation
c0.2 = f32[] constant(0)
ROOT reduce = f32[1024]{0} reduce(loop_fusion, c0.2), dimensions={0,2,3}, to_apply=scalar_add

View File

@@ -27,6 +27,7 @@ cc_library(
srcs = ["conv_emitter.cc"],
hdrs = ["conv_emitter.h"],
deps = [
":conv_emitter_transforms",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
@@ -39,6 +40,23 @@ cc_library(
],
)
cc_library(
name = "conv_emitter_transforms",
srcs = ["conv_emitter_transforms.cc"],
hdrs = ["conv_emitter_transforms.h"],
deps = [
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:support",
"@llvm-project//mlir:Affine",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
"@llvm-project//mlir:TransformUtils",
],
)
tf_cc_test(
name = "conv_emitter_test",
srcs = ["conv_emitter_test.cc"],

View File

@@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/window_util.h"
namespace xla {
@@ -109,48 +110,6 @@ ShapeInfo GetShapeInfo(
return shape_info;
}
bool IsSimpleLoop(mlir::AffineForOp loop) {
return loop.getLowerBoundMap().isSingleConstant() &&
loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
std::next(loop.region().begin()) == loop.region().end();
}
struct BoundAffineMap {
mlir::AffineMap affine_map;
std::vector<mlir::Value> operands;
};
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
return {load.getAffineMap(),
std::vector<mlir::Value>(load.getMapOperands().begin(),
load.getMapOperands().end())};
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
return {store.getAffineMap(),
std::vector<mlir::Value>(store.getMapOperands().begin(),
store.getMapOperands().end())};
} else {
CHECK(false);
}
}
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
BoundAffineMap new_affine,
mlir::OpBuilder builder) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
return builder.create<mlir::AffineLoadOp>(
builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
new_affine.operands);
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
return builder.create<mlir::AffineStoreOp>(
builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
new_affine.affine_map, new_affine.operands);
} else {
CHECK(false);
}
}
void SetMemRef(mlir::Operation* op, mlir::Value memref) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
load.setMemRef(memref);
@@ -161,127 +120,6 @@ void SetMemRef(mlir::Operation* op, mlir::Value memref) {
}
}
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
std::vector<mlir::AffineForOp> loops;
loops.reserve(upper_bounds.size());
for (int64_t dim : upper_bounds) {
auto loop =
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
loops.push_back(loop);
builder = loop.getBodyBuilder();
}
return loops;
}
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
mlir::OpBuilder builder) {
CHECK(IsSimpleLoop(loop));
loop.setUpperBoundMap(mlir::AffineMap::get(
loop.getUpperBoundMap().getNumDims(),
loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
}
// Tile a loop with trip count N by `size`. For now, N has to be a multiple of
// size, but later this constraint will be removed.
//
// The major loop (with trip count N / size) stays as-is, while the minor loop
// (with trip count `size`) will take over the body of `target`, and be placed
// as the new body of `target`.
//
// `target` has to be within the same "perfectly nested loop group" as `loop`.
// See the documentation for mlir::getPerfectlyNestedLoops.
//
// Example:
// Before tiling `loop` with tile size X:
// for (loop in N)
// for (unrelated_loop in ...)
// for (target in ...)
// // pass loop into affine maps
// After:
// for (loop in N / X)
// for (unrelated_loop in ...)
// for (target in ...)
// for (tiled_loop in X)
// // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
//
// Design note:
// TileLoop is different from mlir::tile. At the moment, mlir::tile is not well
// documented about the exact tiling semantics, but the observed behavior is:
// for (i from 0 to N)
// for (unrelated_loop in ...)
// for (target in ...)
// // pass i into affine maps
// =>
// for (i from 0 to N, step = X)
// for (unrelated_loop in ...)
// for (target in ...)
// for (j from i to min(i + X, N), step = 1)
// // pass j into affine maps
//
// There are two differences between mlir::tile and TileLoop:
// * TileLoop always puts the tiling logic "stepping" logic into AffineExprs.
// With that all index calculation is done in AffineExprs and easier to
// analyze in a single place.
// * TileLoop doesn't plan to use use max() and min() to resolve the issue when
// N % X != 0. max() and min() are not representable in AffineExprs.
// TODO(timshen): support the case where N % X != 0.
//
// TODO(timshen): consider the possibility to reuse mlir::tile's logic to
// achieve the same goal.
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
mlir::AffineForOp target) {
CHECK(IsSimpleLoop(loop));
CHECK(IsSimpleLoop(target));
{
llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
getPerfectlyNestedLoops(all_loops, loop);
CHECK(absl::c_linear_search(all_loops, target));
}
auto builder = target.getBodyBuilder();
auto inner_loop =
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
{
auto& inner_operations = inner_loop.getBody()->getOperations();
auto& target_operations = target.getBody()->getOperations();
inner_operations.splice(inner_operations.begin(), target_operations,
target_operations.begin(),
std::prev(target_operations.end(), 2));
mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
}
for (auto& use :
llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
mlir::Operation* owner = use.getOwner();
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
unsigned new_dim = affine_map.operands.size();
affine_map.operands.push_back(inner_loop.getInductionVar());
std::vector<mlir::AffineExpr> replacements;
for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
if (affine_map.operands[i] == loop.getInductionVar()) {
replacements.push_back(builder.getAffineDimExpr(i) * size +
builder.getAffineDimExpr(new_dim));
} else {
replacements.push_back(builder.getAffineDimExpr(i));
}
}
affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
replacements, {}, affine_map.operands.size(), 0);
auto new_op =
CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
owner->replaceAllUsesWith(new_op);
owner->erase();
}
return inner_loop;
}
// Hoist operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
@@ -387,21 +225,6 @@ mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}
// Sinks a segment of perfectly nested loops to the bottom. It implements this
// by rotating the loop nest by rotate_amount.
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
int rotate_amount) {
CHECK_GE(rotate_amount, 0);
std::vector<unsigned> permutation(loops.size());
std::iota(permutation.begin(), permutation.end(), unsigned(0));
std::rotate(permutation.begin(),
permutation.begin() + loops.size() - rotate_amount,
permutation.end());
mlir::interchangeLoops(
llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
permutation);
}
struct InitialMlirConvAnchors {
std::vector<mlir::AffineForOp> cartesian_product_loops;
std::vector<mlir::AffineForOp> reduction_loops;

View File

@@ -0,0 +1,152 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/mlir_gpu/experimental/conv_emitter/conv_emitter_transforms.h"
#include "absl/algorithm/container.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Transforms/LoopUtils.h" // from @llvm-project
#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace mlir_gpu {
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
return {load.getAffineMap(),
std::vector<mlir::Value>(load.getMapOperands().begin(),
load.getMapOperands().end())};
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
return {store.getAffineMap(),
std::vector<mlir::Value>(store.getMapOperands().begin(),
store.getMapOperands().end())};
} else {
CHECK(false);
}
}
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
BoundAffineMap new_affine,
mlir::OpBuilder builder) {
if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
return builder.create<mlir::AffineLoadOp>(
builder.getUnknownLoc(), load.getMemRef(), new_affine.affine_map,
new_affine.operands);
} else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
return builder.create<mlir::AffineStoreOp>(
builder.getUnknownLoc(), store.getValueToStore(), store.getMemRef(),
new_affine.affine_map, new_affine.operands);
} else {
CHECK(false);
}
}
bool IsSimpleLoop(mlir::AffineForOp loop) {
return loop.getLowerBoundMap().isSingleConstant() &&
loop.getLowerBoundMap().getSingleConstantResult() == 0 &&
loop.getStep() == 1 && loop.getUpperBoundMap().getNumResults() == 1 &&
std::next(loop.region().begin()) == loop.region().end();
}
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder) {
std::vector<mlir::AffineForOp> loops;
loops.reserve(upper_bounds.size());
for (int64_t dim : upper_bounds) {
auto loop =
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
loops.push_back(loop);
builder = loop.getBodyBuilder();
}
return loops;
}
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
mlir::OpBuilder builder) {
CHECK(IsSimpleLoop(loop));
loop.setUpperBoundMap(mlir::AffineMap::get(
loop.getUpperBoundMap().getNumDims(),
loop.getUpperBoundMap().getNumSymbols(), {new_bound}));
}
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
mlir::AffineForOp target) {
CHECK(IsSimpleLoop(loop));
CHECK(IsSimpleLoop(target));
{
llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
getPerfectlyNestedLoops(all_loops, loop);
CHECK(absl::c_linear_search(all_loops, target));
}
auto builder = target.getBodyBuilder();
auto inner_loop =
builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, size);
{
auto& inner_operations = inner_loop.getBody()->getOperations();
auto& target_operations = target.getBody()->getOperations();
inner_operations.splice(inner_operations.begin(), target_operations,
target_operations.begin(),
std::prev(target_operations.end(), 2));
mlir::AffineExpr length = loop.getUpperBoundMap().getResult(0);
CHECK_EQ(0, length.cast<mlir::AffineConstantExpr>().getValue() % size);
SetBoundForSimpleLoop(loop, length.ceilDiv(size), builder);
}
for (auto& use :
llvm::make_early_inc_range(loop.getInductionVar().getUses())) {
mlir::Operation* owner = use.getOwner();
BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
unsigned new_dim = affine_map.operands.size();
affine_map.operands.push_back(inner_loop.getInductionVar());
std::vector<mlir::AffineExpr> replacements;
for (int i = 0; i < affine_map.affine_map.getNumDims(); i++) {
if (affine_map.operands[i] == loop.getInductionVar()) {
replacements.push_back(builder.getAffineDimExpr(i) * size +
builder.getAffineDimExpr(new_dim));
} else {
replacements.push_back(builder.getAffineDimExpr(i));
}
}
affine_map.affine_map = affine_map.affine_map.replaceDimsAndSymbols(
replacements, {}, affine_map.operands.size(), 0);
auto new_op =
CloneWithNewAffineMap(owner, affine_map, mlir::OpBuilder(owner));
owner->replaceAllUsesWith(new_op);
owner->erase();
}
return inner_loop;
}
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
int rotate_amount) {
CHECK_GE(rotate_amount, 0);
std::vector<unsigned> permutation(loops.size());
std::iota(permutation.begin(), permutation.end(), unsigned(0));
std::rotate(permutation.begin(),
permutation.begin() + loops.size() - rotate_amount,
permutation.end());
mlir::interchangeLoops(
llvm::ArrayRef<mlir::AffineForOp>(loops.begin(), loops.end()),
permutation);
}
} // namespace mlir_gpu
} // namespace xla

View File

@ -0,0 +1,102 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_
#include "absl/base/integral_types.h"
#include "absl/types/span.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
namespace xla {
namespace mlir_gpu {
struct BoundAffineMap {
mlir::AffineMap affine_map;
std::vector<mlir::Value> operands;
};
BoundAffineMap GetBoundAffineMapFrom(mlir::Operation* op);
mlir::Operation* CloneWithNewAffineMap(mlir::Operation* op,
BoundAffineMap new_affine,
mlir::OpBuilder builder);
bool IsSimpleLoop(mlir::AffineForOp loop);
std::vector<mlir::AffineForOp> CreateNestedSimpleLoops(
absl::Span<const int64_t> upper_bounds, mlir::OpBuilder builder);
void SetBoundForSimpleLoop(mlir::AffineForOp loop, mlir::AffineExpr new_bound,
mlir::OpBuilder builder);
// Tile a loop with trip count N by `size`. For now, N has to be a multiple of
// size, but later this constraint will be removed.
//
// The major loop (with trip count N / size) stays as-is, while the minor loop
// (with trip count `size`) takes over the body of `target` and becomes its new
// body.
//
// `target` has to be within the same "perfectly nested loop group" as `loop`.
// See the documentation for mlir::getPerfectlyNestedLoops.
//
// Example:
// Before tiling `loop` with tile size X:
// for (loop in N)
// for (unrelated_loop in ...)
// for (target in ...)
// // pass loop into affine maps
// After:
// for (loop in N / X)
// for (unrelated_loop in ...)
// for (target in ...)
// for (tiled_loop in X)
// // rewrite all affine exprs from loop to `loop * X + tiled_loop`.
//
// Design note:
// TileLoop is different from mlir::tile. At the moment, the exact tiling
// semantics of mlir::tile are not well documented, but the observed behavior
// is:
// for (i from 0 to N)
// for (unrelated_loop in ...)
// for (target in ...)
// // pass i into affine maps
// =>
// for (i from 0 to N, step = X)
// for (unrelated_loop in ...)
// for (target in ...)
// for (j from i to min(i + X, N), step = 1)
// // pass j into affine maps
//
// There are two differences between mlir::tile and TileLoop:
// * TileLoop always puts the tiling "stepping" logic into AffineExprs.
//   With that, all index calculation is done in AffineExprs and is easier to
//   analyze in a single place.
// * TileLoop doesn't plan to use max() and min() to resolve the issue when
// N % X != 0. max() and min() are not representable in AffineExprs.
// TODO(timshen): support the case where N % X != 0.
//
// TODO(timshen): consider reusing mlir::tile's logic to achieve the same goal.
mlir::AffineForOp TileLoop(mlir::AffineForOp loop, int64_t size,
mlir::AffineForOp target);
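The before/after loop forms in the comment above can be sanity-checked with ordinary loops. Below is a minimal standalone sketch in plain C++ (not MLIR); the concrete N, X, and the array are made up for illustration, and N is assumed to be a multiple of X as TileLoop currently requires.

#include <cassert>
#include <vector>

int main() {
  constexpr int N = 12, X = 4;  // N must be a multiple of X for now.
  std::vector<int> original(N), tiled(N);

  // Original loop: for (i in N), pass `i` into the affine maps.
  for (int i = 0; i < N; ++i) original[i] = 2 * i + 1;

  // After TileLoop: the major loop runs N / X times, the minor loop X times,
  // and every use of the induction variable is rewritten to `i * X + j`.
  for (int i = 0; i < N / X; ++i)
    for (int j = 0; j < X; ++j) tiled[i * X + j] = 2 * (i * X + j) + 1;

  assert(original == tiled);  // the rewrite visits exactly the same indices
  return 0;
}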
// Sinks a segment of perfectly nested loops to the bottom. It implements this
// by rotating the loop nest by rotate_amount.
void SinkPerfectlyNestedLoops(absl::Span<const mlir::AffineForOp> loops,
int rotate_amount);
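SinkPerfectlyNestedLoops works by rotating the loop order. The sketch below reproduces the permutation it builds with std::iota and std::rotate for a hypothetical four-deep loop nest and a rotate_amount of 1; the loop names are made up for illustration.

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const std::vector<const char*> loops = {"n", "h", "w", "c"};
  const int rotate_amount = 1;

  // Permutation construction mirroring SinkPerfectlyNestedLoops.
  std::vector<unsigned> permutation(loops.size());
  std::iota(permutation.begin(), permutation.end(), 0u);
  std::rotate(permutation.begin(),
              permutation.begin() + loops.size() - rotate_amount,
              permutation.end());

  // With rotate_amount = 1 the permutation is {3, 0, 1, 2}.
  for (unsigned p : permutation) std::printf("%s ", loops[p]);
  std::printf("\n");  // prints: c n h w
  return 0;
}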
} // namespace mlir_gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_MLIR_GPU_EXPERIMENTAL_CONV_EMITTER_CONV_EMITTER_TRANSFORMS_H_

View File

@ -320,10 +320,7 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
end = {k, std::min(i * block_size, n)};
}
if (!left_side) {
std::swap(end[0], end[1]);
}
if (transpose_a) {
if (!left_side ^ transpose_a) {
std::swap(start[0], start[1]);
std::swap(end[0], end[1]);
}
@ -337,16 +334,12 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
}
XlaOp x_update;
auto zero = Zero(builder, S32);
auto start_index = ConstantR0WithType(builder, S32, j * block_size);
std::vector<XlaOp> update_starts = {start_index, zero};
if (left_side) {
x_update =
BatchDot(inv_block, transpose_a, remainder, false, precision);
} else {
x_update =
BatchDot(remainder, false, inv_block, transpose_a, precision);
std::swap(update_starts[0], update_starts[1]);
}
if (i == 0) {

View File

@ -458,7 +458,7 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
Array2D<float> avals(spec.m, spec.m);
avals.FillRandom(1.0);
for (int i = 0; i < spec.m; ++i) {
avals(i, i) += 10;
avals(i, i) += 30;
}
std::pair<int, int> bdims = spec.left_side ? std::make_pair(spec.m, spec.n)
@ -481,13 +481,13 @@ XLA_TEST_P(TriangularSolveParametricTest, Random) {
}
ComputeAndCompareR2<float>(&builder, bvals, {a_data.get(), b_data.get()},
ErrorSpec(1e-2, 1e-2));
ErrorSpec(3e-2, 3e-2));
}
std::vector<TriangularSolveTestSpec> TriangularSolveTests() {
std::vector<TriangularSolveTestSpec> specs;
for (int m : {5, 10}) {
for (int n : {5, 10}) {
for (int m : {5, 10, 150}) {
for (int n : {5, 10, 150}) {
for (bool left_side : {false, true}) {
for (bool lower : {false, true}) {
for (TriangularSolveOptions::Transpose transpose_a :

View File

@ -2550,6 +2550,7 @@ filegroup(
"common_runtime/executor_factory.h",
"common_runtime/function_optimization_registry.h",
"common_runtime/graph_optimizer.h",
"common_runtime/graph_view.h",
"common_runtime/input_colocation_exemption_registry.h",
"common_runtime/isolate_placer_inspection_required_ops_pass.h",
"common_runtime/local_device.h",
@ -2613,6 +2614,7 @@ tf_cuda_library(
"common_runtime/function_optimization_registry.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
"common_runtime/graph_view.cc",
"common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/input_colocation_exemption_registry.cc",
"common_runtime/inspecting_placer.cc",

View File

@ -382,7 +382,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
}
void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
AddTraceMe("MemoryAllocation", num_bytes);
AddTraceMe("MemoryAllocation", ptr);
return ptr;
}
@ -390,7 +390,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
if (Extend(unused_alignment, rounded_bytes)) {
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
AddTraceMe("MemoryAllocation", num_bytes);
AddTraceMe("MemoryAllocation", ptr);
return ptr;
}
}
@ -403,7 +403,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
if (MergeTimestampedChunks(rounded_bytes)) {
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
AddTraceMe("MemoryAllocation", num_bytes);
AddTraceMe("MemoryAllocation", ptr);
return ptr;
}
}
@ -417,7 +417,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
Extend(unused_alignment, rounded_bytes)) {
ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
AddTraceMe("MemoryAllocation", num_bytes);
AddTraceMe("MemoryAllocation", ptr);
return ptr;
}
}
@ -439,7 +439,7 @@ void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
}
void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
int64 requested_bytes) {
const void* chunk_ptr) {
// Internal users will see the memory profile with default trace level.
auto traceme_level = profiler::TraceMeLevel::kVerbose;
#ifdef PLATFORM_GOOGLE
@ -451,12 +451,17 @@ void BFCAllocator::AddTraceMe(absl::string_view traceme_name,
AllocatorStats stats = stats_;
int64 bytes_available =
memory_limit_ - stats.bytes_reserved - stats.bytes_in_use;
BFCAllocator::Chunk* chunk =
ChunkFromHandle(region_manager_.get_handle(chunk_ptr));
return absl::StrCat(traceme_name, "#allocator_name=", name_,
",bytes_reserved=", stats.bytes_reserved,
",bytes_allocated=", stats.bytes_in_use,
",bytes_available=", bytes_available,
",peak_bytes_in_use=", stats.peak_bytes_in_use,
",requested_bytes=", requested_bytes,
",requested_bytes=", chunk->requested_size,
",allocation_bytes=", chunk->size,
",addr=", reinterpret_cast<uint64>(chunk_ptr),
",tf_op=", pending_op_name, ",id=", pending_step_id,
"#");
},
@ -595,8 +600,10 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) {
BFCAllocator::ChunkHandle h = region_manager_.get_handle(ptr);
CHECK(h != kInvalidChunkHandle);
int64 requested_bytes = ChunkFromHandle(h)->requested_size;
MarkFree(h);
// TraceMe needs to be added after MarkFree and before InsertFreeChunkIntoBin
// for correct memory stats.
AddTraceMe("MemoryDeallocation", ptr);
// Consider coalescing it.
if (timing_counter_) {
@ -609,8 +616,6 @@ void BFCAllocator::DeallocateRawInternal(void* ptr) {
if (VLOG_IS_ON(4)) {
LOG(INFO) << "F: " << RenderOccupancy();
}
AddTraceMe("MemoryDeallocation", -requested_bytes);
}
// Merges h1 and h2 when Chunk(h1)->next is h2 and Chunk(h2)->prev is c1.

View File

@ -116,8 +116,9 @@ class BFCAllocator : public Allocator {
TF_EXCLUSIVE_LOCKS_REQUIRED(lock_);
// Add TraceMe (in memory allocation and deallocation) for memory stats
// profiling. The requested_bytes can be negative if it's a deallocation.
void AddTraceMe(absl::string_view traceme_name, int64 requested_bytes)
// profiling. The chunk_ptr is passed to get information such as address,
// chunk size and requested_size.
void AddTraceMe(absl::string_view traceme_name, const void* chunk_ptr)
TF_EXCLUSIVE_LOCKS_REQUIRED(lock_);
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator

View File

@ -1337,10 +1337,11 @@ Status DirectSession::CreateExecutors(
device_mgr_.get(), options_.env, &options_.config, graph_def_version,
func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
/*parent=*/nullptr, custom_kernel_creator, session_metadata,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}));
Rendezvous::Factory{
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}}));
GraphOptimizer optimizer(optimizer_opts);
for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {

View File

@ -128,11 +128,11 @@ void EagerContext::ResetPFLR(const DeviceMgr* device_mgr, Env* env,
thread::ThreadPool* thread_pool,
DistributedFunctionLibraryRuntime* cluster_flr,
const CustomKernelCreator* custom_kernel_creator) {
Rendezvous::Factory rendezvous_factory =
Rendezvous::Factory rendezvous_factory{
[this](const int64 step_id, const DeviceMgr*, Rendezvous** r) {
*r = CreateRendezvous(step_id);
return Status::OK();
};
}};
if (lazy_copy_function_remote_inputs_) {
pflr_.reset(new eager::EagerProcessFunctionLibraryRuntime(
device_mgr, env, config, graph_def_version, lib_def, optimizer_options,
@ -1102,6 +1102,7 @@ Status EagerContext::UpdateRemoteMaster(
if (rendezvous_ != nullptr) rendezvous_->Unref();
rendezvous_ = r;
remote_eager_workers_ = std::move(remote_eager_workers);
pflr_->InitializeDeviceSet();
InitPrioritizedDeviceTypeList();
default_executor_.ClearError();

View File

@ -583,11 +583,11 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
if (executor.Async()) {
const DataTypeVector& output_dtypes = kernel->output_dtypes();
for (int i = 0; i < num_outputs; ++i) {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
retvals[i] = TensorHandle::CreateEmptyLocalHandle(
/* d= */ ctx.CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i),
output_dtypes[i], &ctx, &retvals[i]));
output_dtypes[i], &ctx);
}
auto node = absl::make_unique<AsyncExecuteNode>(
&ctx, op->Inputs(), op->remote_func_params(), std::move(kernel),
@ -773,18 +773,8 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
// remote device here. We just need to know that it is remote. If we need
// to copy this tensor to this process, the remote end will know the
// correct device of this handle.
Status status = TensorHandle::CreateUnshapedRemoteHandle(
id, i, remote_task, output_dtypes[i], op_device, &ctx, &retvals[i]);
if (!status.ok()) {
for (int j = 0; j < i; ++j) {
retvals[j]->PoisonRemote(
errors::Internal(
"Failed to construct unshaped remote tensor handle at index ",
i, " for op ", op->Name()),
op_device, ctx.GetContextViewId());
}
return status;
}
retvals[i] = TensorHandle::CreateUnshapedRemoteHandle(
id, i, remote_task, output_dtypes[i], op_device, &ctx);
}
if (ctx.LazyCopyFunctionRemoteInputs()) {
@ -1056,12 +1046,11 @@ Status EagerKernelExecute(
for (int i = 0; i < retvals.size(); ++i) {
if (retvals[i] == nullptr) {
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
retvals[i] = TensorHandle::CreateLocalHandle(
std::move(outputs[i]),
/* d= */ ctx->CanonicalDevice(kernel->OutputDevice(i)),
/* op_device= */ kernel->device(),
/* resource_device= */ kernel->OutputResourceDevice(i), ctx,
&retvals[i]));
/* resource_device= */ kernel->OutputResourceDevice(i), ctx);
} else {
DCHECK_EQ(kernel->device(), retvals[i]->op_device());
DCHECK_EQ(ctx->CanonicalDevice(kernel->OutputDevice(i)),
@ -1100,8 +1089,8 @@ Status LocalEagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
h->Ref();
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
d, dstd, h->resource_device(), h->dtype, ctx, result));
*result = TensorHandle::CreateEmptyLocalHandle(
d, dstd, h->resource_device(), h->dtype, ctx);
}
Status s;
@ -1169,9 +1158,9 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
h->Ref();
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateEmptyLocalHandle(
*result = TensorHandle::CreateEmptyLocalHandle(
/* d= */ d, /* op_device= */ device,
/*resource_device=*/nullptr, h->dtype, ctx, result));
/*resource_device=*/nullptr, h->dtype, ctx);
}
} else {
if (mirror) {
@ -1194,8 +1183,8 @@ Status EagerCopyToDevice(TensorHandle* h, EagerContext* ctx,
h->Ref();
*result = h;
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
recv_op_id, 0, remote_task, h->dtype, device, ctx, result));
*result = TensorHandle::CreateUnshapedRemoteHandle(
recv_op_id, 0, remote_task, h->dtype, device, ctx);
}
}

View File

@ -50,8 +50,8 @@ void EagerProcessFunctionLibraryRuntime::Run(
std::move(done));
}
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done =
ApplyCleanUpToDoneCallback(cleanup_items, done, /*rendezvous=*/nullptr);
done = ApplyCleanUpToDoneCallback(cleanup_items, done, opts.step_id,
/*rendezvous=*/nullptr);
auto get_component_args = [&args](const ComponentFunctionData& comp_data,
InternalArgs* comp_args) -> Status {

View File

@ -106,40 +106,36 @@ Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
return GetResourceHandleInfoImpl(get_resource_info);
}
Status TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t,
TensorHandle** h) {
TensorHandle* TensorHandle::CreateLocalHandle(const tensorflow::Tensor& t) {
// TODO(b/136608821): Move away from nullptr
tensorflow::Tensor tensor = t;
return CreateLocalHandle(std::move(tensor),
/*d=*/nullptr,
/*op_device=*/nullptr,
/*ctx=*/nullptr, h);
/*ctx=*/nullptr);
}
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, EagerContext* ctx,
TensorHandle** h) {
return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx, h);
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device,
EagerContext* ctx) {
return CreateLocalHandle(std::move(t), d, op_device, nullptr, ctx);
}
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device,
Device* resource_device,
EagerContext* ctx, TensorHandle** h) {
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device,
Device* resource_device,
EagerContext* ctx) {
if (t.dtype() == DT_RESOURCE && t.NumElements() > 0) {
*h = new TensorHandle(std::move(t), d, op_device, ctx);
return new TensorHandle(std::move(t), d, op_device, ctx);
} else {
*h = new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
return new TensorHandle(std::move(t), d, op_device, resource_device, ctx);
}
return Status::OK();
}
Status TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx, TensorHandle** h) {
*h = new TensorHandle(std::move(t), d, ctx);
return Status::OK();
TensorHandle* TensorHandle::CreateLocalHandle(tensorflow::Tensor&& t,
CustomDevice* d,
EagerContext* ctx) {
return new TensorHandle(std::move(t), d, ctx);
}
TensorHandle::TensorHandle(tensorflow::Tensor&& t, Device* d, Device* op_device,
@ -190,13 +186,11 @@ TensorHandle::TensorHandle(tensorflow::Tensor&& t, CustomDevice* d,
<< " tensor: " << t.DeviceSafeDebugString();
}
Status TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
DataType dtype, EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(d, op_device, resource_device, dtype, ctx);
return Status::OK();
TensorHandle* TensorHandle::CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
DataType dtype,
EagerContext* ctx) {
return new TensorHandle(d, op_device, resource_device, dtype, ctx);
}
TensorHandle::TensorHandle(Device* d, Device* op_device,
@ -214,14 +208,10 @@ TensorHandle::TensorHandle(Device* d, Device* op_device,
}
#if !defined(IS_MOBILE_PLATFORM)
Status TensorHandle::CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* d,
EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
return Status::OK();
TensorHandle* TensorHandle::CreateUnshapedRemoteHandle(
int64 op_id, int32 output_num, const string& remote_task, DataType dtype,
Device* d, EagerContext* ctx) {
return new TensorHandle(op_id, output_num, remote_task, dtype, d, ctx);
}
TensorHandle::TensorHandle(int64 op_id, int32 output_num,
@ -239,13 +229,11 @@ TensorHandle::TensorHandle(int64 op_id, int32 output_num,
<< " device: " << VariantDeviceDebugString(device_);
}
Status TensorHandle::CreateLazyRemoteHandle(int64 op_id, int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx,
TensorHandle** h) {
*h = new TensorHandle(op_id, output_num, dtype, d, ctx);
return Status::OK();
TensorHandle* TensorHandle::CreateLazyRemoteHandle(int64 op_id,
int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx) {
return new TensorHandle(op_id, output_num, dtype, d, ctx);
}
TensorHandle::TensorHandle(int64 op_id, int32 output_num, DataType dtype,

View File

@ -77,27 +77,27 @@ class TensorHandle : public core::RefCounted {
public:
// TensorHandle with no assigned device
static Status CreateLocalHandle(const tensorflow::Tensor& t,
TensorHandle** h);
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, EagerContext* ctx,
TensorHandle** h);
static Status CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, Device* resource_device,
EagerContext* ctx, TensorHandle** h);
static Status CreateLocalHandle(tensorflow::Tensor&& t, CustomDevice* d,
EagerContext* ctx, TensorHandle** h);
static Status CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device, DataType dtype,
EagerContext* ctx, TensorHandle** h);
static TensorHandle* CreateLocalHandle(const tensorflow::Tensor& t);
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device, EagerContext* ctx);
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t, Device* d,
Device* op_device,
Device* resource_device,
EagerContext* ctx);
static TensorHandle* CreateLocalHandle(tensorflow::Tensor&& t,
CustomDevice* d, EagerContext* ctx);
static TensorHandle* CreateEmptyLocalHandle(Device* d, Device* op_device,
Device* resource_device,
DataType dtype,
EagerContext* ctx);
#if !defined(IS_MOBILE_PLATFORM)
static Status CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* d,
EagerContext* ctx, TensorHandle** h);
static Status CreateLazyRemoteHandle(int64 op_id, int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx, TensorHandle** h);
static TensorHandle* CreateUnshapedRemoteHandle(int64 op_id, int32 output_num,
const string& remote_task,
DataType dtype, Device* d,
EagerContext* ctx);
static TensorHandle* CreateLazyRemoteHandle(int64 op_id, int32 output_num,
DataType dtype, Device* d,
EagerContext* ctx);
#endif // IS_MOBILE_PLATFORM
~TensorHandle() override { DVLOG(3) << "Deleting TensorHandle " << this; }

View File

@ -38,15 +38,10 @@ TEST(TensorHandle_ShapeTest, AsyncShape) {
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
tensorflow::ContextMirroringPolicy::MIRRORING_NONE, false, false,
&device_mgr, false, nullptr, nullptr, nullptr);
TensorHandle* sync_th;
EXPECT_TRUE(TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr,
ctx, &sync_th)
.ok());
TensorHandle* async_th;
EXPECT_TRUE(TensorHandle::CreateEmptyLocalHandle(nullptr, nullptr, nullptr,
DataType::DT_UINT16, ctx,
&async_th)
.ok());
TensorHandle* sync_th =
TensorHandle::CreateLocalHandle(std::move(t), nullptr, nullptr, ctx);
TensorHandle* async_th = TensorHandle::CreateEmptyLocalHandle(
nullptr, nullptr, nullptr, DataType::DT_UINT16, ctx);
EXPECT_TRUE(async_th->CopyInferenceShape(sync_th).ok());

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/executor_factory.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/common_runtime/renamed_device.h"
@ -124,19 +125,6 @@ void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
} // namespace nodestats
class ExecutorImpl;
class GraphView;
struct EdgeInfo {
int dst_id;
int output_slot : 31;
// true if this is the last info for output_slot in the EdgeInfo list.
bool is_last : 1;
int input_slot;
};
struct ControlEdgeInfo {
int dst_id;
};
// Time the execution of kernels (in CPU cycles). Used to dynamically identify
// inexpensive kernels which can be dispatched inline.
@ -148,196 +136,9 @@ struct KernelTimer {
}
};
// Compact structure representing a graph node and its associated kernel.
//
// Each NodeItem is an element of exactly one GraphView.
struct NodeItem {
NodeItem() {}
// The index of this node's item in its GraphView.
int node_id = -1;
// Cached attributes of this node for fast lookup.
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
bool is_constant_enter : 1; // True iff IsEnter(node) and
// node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_source : 1; // True iff IsSource(node)
// True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
bool is_enter_exit_or_next_iter : 1;
bool is_transfer_node : 1; // True iff IsTransferNode(node)
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
bool is_next_iteration : 1; // True iff IsNextIteration(node)
bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp")
bool
is_any_consumer_merge_or_control_trigger : 1; // True iff the destination
// of any output edge is a
// merge or control trigger
// node.
// The kernel for this node.
OpKernel* kernel = nullptr;
// If the kernel is a Const op, this points to the constant tensor.
const Tensor* const_tensor = nullptr;
// Cached values of node->num_inputs() and node->num_outputs(), to
// avoid levels of indirection.
int num_inputs;
int num_outputs;
// ExecutorImpl::tensors_[input_start] is the 1st positional input
// for this node.
int input_start = 0;
// Number of output edges, excluding control edges.
int32 num_output_edges;
// Number of output control edges.
int32 num_output_control_edges;
// If non-null, contains an array of num_outputs bools, where the ith bool
// is true if and only if the ith output is consumed by another node.
std::unique_ptr<bool[]> outputs_required;
gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
num_output_edges);
}
gtl::ArraySlice<EdgeInfo> output_edges() const {
return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
}
gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
num_output_control_edges);
}
DataType input_type(int i) const {
DCHECK_LT(i, num_inputs);
return static_cast<DataType>(input_type_base()[i]);
}
DataType output_type(int i) const {
DCHECK_LT(i, num_outputs);
return static_cast<DataType>(output_type_base()[i]);
}
// Return array of per-output allocator attributes.
const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
// Return array of expected input index from which each output should
// be forwarded:
// kNeverForward (-2) for DO NOT FORWARD (must allocate).
// kNoReservation (-1) for no expected forwarding.
// 0... for forward from that input.
const int* forward_from() const { return forward_from_base(); }
string DebugString() const {
string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
if (is_source) {
strings::StrAppend(&ret, " source}");
} else {
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}");
}
return ret;
}
private:
friend class GraphView;
// Variable length section starts immediately after *this
// (uint8 is enough for DataType).
// EdgeInfo out_edges[num_out_edges];
// AllocatorAttributes output_attr[num_outputs];
// int forward_from[num_outputs];
// uint8 input_type[num_inputs];
// uint8 output_type[num_outputs];
// Return pointer to variable length section.
char* var() const {
return const_cast<char*>(reinterpret_cast<const char*>(this) +
sizeof(NodeItem));
}
EdgeInfo* output_edge_base() const {
return reinterpret_cast<EdgeInfo*>(var());
}
ControlEdgeInfo* output_control_edge_base() const {
return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
num_output_edges);
}
AllocatorAttributes* output_attr_base() const {
return reinterpret_cast<AllocatorAttributes*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges);
}
int* forward_from_base() const {
return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) *
num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs);
}
uint8* input_type_base() const {
return reinterpret_cast<uint8*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
}
uint8* output_type_base() const {
return reinterpret_cast<uint8*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
sizeof(uint8) * num_inputs);
}
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
};
typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
// Immutable view of a Graph organized for efficient execution.
class GraphView {
public:
GraphView() : space_(nullptr) {}
~GraphView();
Status Initialize(const Graph* g);
Status SetAllocAttrs(const Graph* g, const Device* device);
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
NodeItem* node(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
uint32 offset = node_offsets_[id];
return ((offset == kuint32max)
? nullptr
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
}
int32 num_nodes() const { return num_nodes_; }
private:
char* InitializeNode(char* ptr, const Node* n);
size_t NodeItemBytes(const Node* n);
int32 num_nodes_ = 0;
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
char* space_; // NodeItem objects are allocated here
TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
};
class ExecutorImpl : public Executor {
public:
explicit ExecutorImpl(const LocalExecutorParams& p) : params_(p), gview_() {
@ -499,237 +300,6 @@ class ExecutorImpl : public Executor {
TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
};
// Infer memory allocation attributes of a node n's output,
// based on its use node dst. Note that dst might not be directly
// connected to n by a single edge, but might be a downstream
// consumer of n's output by reference. *attr is updated with any
// necessary attributes.
Status InferAllocAttr(const Node* n, const Node* dst,
const DeviceNameUtils::ParsedName& local_dev_name,
AllocatorAttributes* attr);
GraphView::~GraphView() {
static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
"Update code if AllocatorAttributes gains a destructor");
static_assert(std::is_trivially_destructible<EdgeInfo>::value,
"Update code if EdgeInfo gains a destructor");
for (int i = 0; i < num_nodes_; i++) {
NodeItem* n = node(i);
if (n != nullptr) {
n->NodeItem::~NodeItem();
// Memory for "n" itself is held in space_ & gets cleaned up below
}
}
delete[] node_offsets_;
delete[] space_;
}
typedef std::tuple<int32, int32> OutputAndControlEdges;
static OutputAndControlEdges CountOutputEdges(const Node* n) {
DCHECK_LE(n->out_edges().size(), kint32max);
int32 num_output_edges = 0;
int32 num_output_control_edges = 0;
for (auto e : n->out_edges()) {
if (IsSink(e->dst())) continue;
if (e->IsControlEdge()) {
++num_output_control_edges;
} else {
++num_output_edges;
}
}
return OutputAndControlEdges(num_output_edges, num_output_control_edges);
}
size_t GraphView::NodeItemBytes(const Node* n) {
int32 num_output_edges;
int32 num_output_control_edges;
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
// Compute number of bytes needed for NodeItem and variable length data.
// We do not subtract sizeof(var) since num_inputs/num_outputs might
// both be zero.
const size_t raw_bytes =
sizeof(NodeItem) // Fixed
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
+ num_output_control_edges * //
sizeof(ControlEdgeInfo) // output_control_edges[...]
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
+ num_outputs * sizeof(int) // forward_from[num_outputs]
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
+ num_outputs * sizeof(uint8); // output_type[num_outputs]
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
static_assert(kItemAlignment % alignof(NodeItem) == 0,
"NodeItem must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
"EdgeInfo must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0,
"ControlEdgeInfo must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
"AllocatorAttributes must be aligned with kItemAlignment");
static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
"NodeItem must be aligned with EdgeInfo");
static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
"NodeItem must be aligned with AllocatorAttributes");
static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
"EdgeInfo must be aligned with AllocatorAttributes");
const size_t bytes =
((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
return bytes;
}
char* GraphView::InitializeNode(char* ptr, const Node* n) {
const int id = n->id();
CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor
const size_t bytes = NodeItemBytes(n);
constexpr size_t kItemAlignment = sizeof(NodeItem*);
CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
// We store a 32-bit offset relative to the beginning of space_, so that we
// only need an array of 32-bit values to map from node id to the NodeItem*,
// (versus 64 bits on most machines if we just stored an array of NodeItem*
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
// values as "int" vs "size_t" in CHECK_LE.
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
const uint32 offset = static_cast<uint32>(ptr - space_);
node_offsets_[id] = offset;
ptr += bytes;
int32 num_output_edges;
int32 num_output_control_edges;
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
new (item) NodeItem();
item->num_inputs = num_inputs;
item->num_outputs = num_outputs;
item->num_output_edges = num_output_edges;
item->num_output_control_edges = num_output_control_edges;
// Fill output edges.
// Keep track of the last EdgeInfo in the EdgeInfo array that references
// a given output slot. For all but the last, we need to do a copy of the
// Tensor when propagating results downstream in the graph, but for the
// last one, we can just do a move of the Tensor object to propagate it.
gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
EdgeInfo* dst_edge = item->output_edge_base();
for (auto e : n->out_edges()) {
if (e->IsControlEdge()) continue;
dst_edge->dst_id = e->dst()->id();
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
dst_edge->output_slot = e->src_output();
dst_edge->is_last = false;
const int output_slot = dst_edge->output_slot;
if (output_slot >= 0) {
last_indices[output_slot] = dst_edge;
}
// NOTE: The `input_slot` will be rewritten to the frame-wide offset later
// in `ExecutorImpl::Initialize()`.
dst_edge->input_slot = e->dst_input();
dst_edge++;
}
for (EdgeInfo* edge_info : last_indices) {
if (edge_info != nullptr) {
edge_info->is_last = true;
}
}
ControlEdgeInfo* dst_control_edge = item->output_control_edge_base();
for (auto e : n->out_edges()) {
if (!e->IsControlEdge() || IsSink(e->dst())) continue;
dst_control_edge->dst_id = e->dst()->id();
dst_control_edge++;
}
AllocatorAttributes* output_attrs = item->output_attr_base();
for (int i = 0; i < num_outputs; i++) {
new (&output_attrs[i]) AllocatorAttributes();
}
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
uint8* input_types = item->input_type_base();
for (int i = 0; i < num_inputs; i++) {
input_types[i] = static_cast<uint8>(n->input_type(i));
DCHECK_EQ(item->input_type(i), n->input_type(i));
}
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
{
std::vector<int> forward_input;
Status fwd_status =
GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
std::vector<int> scoped_allocator_attrs;
Status sa_status =
GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
int* forward_from = item->forward_from_base();
uint8* output_types = item->output_type_base();
for (int i = 0; i < num_outputs; ++i) {
output_types[i] = static_cast<uint8>(n->output_type(i));
DCHECK_EQ(item->output_type(i), n->output_type(i));
forward_from[i] = OpKernelContext::Params::kNoReservation;
if (sa_status.ok()) {
for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
if (scoped_allocator_attrs[j] == i) {
// This output slot must be explicitly allocated from a
// ScopedAllocator.
forward_from[i] = OpKernelContext::Params::kNeverForward;
DCHECK_EQ(output_attrs[i].scope_id, 0);
output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
}
}
}
if (fwd_status.ok() &&
forward_from[i] == OpKernelContext::Params::kNoReservation) {
DCHECK_EQ(forward_input.size() % 2, 0);
for (int j = 0; j < forward_input.size(); j += 2) {
if (forward_input[j + 1] == i) {
DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
forward_from[i] = forward_input[j];
break;
}
}
}
}
}
return ptr;
}
Status GraphView::Initialize(const Graph* g) {
CHECK(node_offsets_ == nullptr);
const int num_nodes = g->num_node_ids();
num_nodes_ = num_nodes;
size_t total_bytes = 0;
for (const Node* n : g->nodes()) {
if (n->out_edges().size() > kint32max) {
return errors::InvalidArgument(
"The executor cannot handle nodes with more than ", kint32max,
" output edges. Node ", n->name(), " had ", n->out_edges().size(),
" output edges.");
}
total_bytes += NodeItemBytes(n);
}
node_offsets_ = new uint32[num_nodes];
for (int i = 0; i < num_nodes; i++) {
node_offsets_[i] = kuint32max;
}
space_ = new char[total_bytes]; // NodeItem objects are allocated here
char* ptr = space_;
for (const Node* n : g->nodes()) {
ptr = InitializeNode(ptr, n);
}
CHECK_EQ(ptr, space_ + total_bytes);
return Status::OK();
}
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
size_t* max_dead_count) {
const size_t num_in_edges = n->in_edges().size();
@ -908,154 +478,6 @@ bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
return false;
}
void GraphView::SetScopedAllocatorAttrs(
const std::vector<const Node*>& sa_nodes) {
for (const Node* sa : sa_nodes) {
NodeItem* sa_item = node(sa->id());
AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
// Control edges out of the ScopedAllocator should be use instances, but may
// include a few other nodes.
for (const auto& e : sa->out_edges()) {
if (IsSink(e->dst()) || !e->IsControlEdge()) {
continue;
}
Node* use_node = e->dst();
NodeItem* item = node(use_node->id());
AllocatorAttributes* use_attrs = item->output_attr_base();
std::vector<int> scoped_allocator_attrs;
Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
&scoped_allocator_attrs);
if (!s.ok()) {
VLOG(2) << "Failed to find expected ScopedAllocator attr on "
<< use_node->name();
continue;
}
// There can be more than one output using ScopedAllocation, but this
// analysis assumes they use the same ScopedAllocator.
for (const auto& e : use_node->out_edges()) {
if (IsSink(e->dst()) || !e->IsControlEdge()) {
AllocatorAttributes attr;
if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
e->src_output(), &attr)) {
// Set the scope_id on this use instance node.
(use_attrs + e->src_output())->Merge(attr);
// Propagate the other attributes of this node back to the SA node.
attr = *(use_attrs + e->src_output());
attr.scope_id = 0;
sa_attrs->Merge(attr);
}
}
}
}
}
}
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
Status s;
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
std::vector<const Node*> scoped_allocator_instances;
for (const Node* n : g->nodes()) {
NodeItem* item = node(n->id());
AllocatorAttributes* attrs = item->output_attr_base();
if (IsScopedAllocator(n)) {
scoped_allocator_instances.push_back(n);
}
// Examine the out edges of each node looking for special use
// cases that may affect memory allocation attributes.
for (const auto& e : n->out_edges()) {
if (!e->IsControlEdge()) {
AllocatorAttributes attr;
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
if (!s.ok()) return s;
if (attr.value != 0 || attr.scope_id != 0) {
attrs[e->src_output()].Merge(attr);
}
}
}
for (int out = 0; out < n->num_outputs(); out++) {
const OpKernel* op_kernel = item->kernel;
DCHECK_LT(out, op_kernel->output_memory_types().size());
bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
if (on_host) {
AllocatorAttributes h;
h.set_on_host(on_host);
attrs[out].Merge(h);
}
}
}
SetScopedAllocatorAttrs(scoped_allocator_instances);
return s;
}
Status InferAllocAttr(const Node* n, const Node* dst,
const DeviceNameUtils::ParsedName& local_dev_name,
AllocatorAttributes* attr) {
Status s;
// Note that it's possible for *n to be a Recv and *dst to be a Send,
// so these two cases are not mutually exclusive.
if (IsRecv(n)) {
string src_name;
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_src_name;
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
n->name());
return s;
}
if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
// Value is going to be the sink of an RPC.
attr->set_nic_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
} else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
parsed_src_name.type != "CPU") {
// Value is going to be the sink of a local DMA from GPU to CPU (or
// other types of accelerators).
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
} else {
VLOG(2) << "default alloc case local type " << local_dev_name.type
<< " remote type " << parsed_src_name.type;
}
}
if (IsSend(dst)) {
string dst_name;
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_dst_name;
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
n->name());
return s;
}
if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
// Value is going to be the source of an RPC.
attr->set_nic_compatible(true);
VLOG(2) << "node " << n->name() << " is the source of an RPC out";
} else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
parsed_dst_name.type != "CPU") {
// Value is going to be the source of a local DMA from CPU to GPU (or
// other types of accelerators).
// Note that this does not cover the case where the allocation of the
// output tensor is not generated by the src: n.
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
} else {
VLOG(2) << "default alloc case local type " << local_dev_name.type
<< " remote type " << parsed_dst_name.type;
}
}
if (n->IsCollective()) {
// We'll make the sweeping assumption that any collective op is going
// to be involved in network i/o.
attr->set_nic_compatible(true);
}
return s;
}
// The state associated with one invocation of ExecutorImpl::Run.
// ExecutorState dispatches nodes when they become ready and keeps
// track of how many predecessors of a node have not done (pending_).

View File

@ -164,10 +164,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, /*thread_pool=*/nullptr,
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
/*session_metadata=*/nullptr,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}));
Rendezvous::Factory{
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}}));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
flr1_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:1");
flr2_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:2");

View File

@ -68,10 +68,11 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, default_thread_pool,
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
/*session_metadata=*/nullptr,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}));
Rendezvous::Factory{
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}}));
flr0_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
}

View File

@ -0,0 +1,442 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/graph_view.h"
#include <atomic>
#include <deque>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
string NodeItem::DebugString() const {
string ret = strings::StrCat("{name:'", kernel->name(), "' id:", node_id);
if (is_source) {
strings::StrAppend(&ret, " source}");
} else {
strings::StrAppend(&ret, " def:{", SummarizeNodeDef(kernel->def()), "}}");
}
return ret;
}
GraphView::~GraphView() {
static_assert(std::is_trivially_destructible<AllocatorAttributes>::value,
"Update code if AllocatorAttributes gains a destructor");
static_assert(std::is_trivially_destructible<EdgeInfo>::value,
"Update code if EdgeInfo gains a destructor");
for (int i = 0; i < num_nodes_; i++) {
NodeItem* n = node(i);
if (n != nullptr) {
n->NodeItem::~NodeItem();
// Memory for "n" itself is held in space_ & gets cleaned up below
}
}
delete[] node_offsets_;
delete[] space_;
}
namespace {
typedef std::tuple<int32, int32> OutputAndControlEdges;
OutputAndControlEdges CountOutputEdges(const Node* n) {
DCHECK_LE(n->out_edges().size(), kint32max);
int32 num_output_edges = 0;
int32 num_output_control_edges = 0;
for (auto e : n->out_edges()) {
if (IsSink(e->dst())) continue;
if (e->IsControlEdge()) {
++num_output_control_edges;
} else {
++num_output_edges;
}
}
return OutputAndControlEdges(num_output_edges, num_output_control_edges);
}
} // namespace
size_t GraphView::NodeItemBytes(const Node* n) {
int32 num_output_edges;
int32 num_output_control_edges;
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
// Compute number of bytes needed for NodeItem and variable length data.
// We do not subtract sizeof(var) since num_inputs/num_outputs might
// both be zero.
const size_t raw_bytes =
sizeof(NodeItem) // Fixed
+ num_output_edges * sizeof(EdgeInfo) // output_edges[...]
+ num_output_control_edges * //
sizeof(ControlEdgeInfo) // output_control_edges[...]
+ num_outputs * sizeof(AllocatorAttributes) // output_attr[...]
+ num_outputs * sizeof(int) // forward_from[num_outputs]
+ num_inputs * sizeof(uint8) // input_type[num_inputs]
+ num_outputs * sizeof(uint8); // output_type[num_outputs]
static constexpr size_t kItemAlignment = sizeof(NodeItem*);
static_assert(kItemAlignment % alignof(NodeItem) == 0,
"NodeItem must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(EdgeInfo) == 0,
"EdgeInfo must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(ControlEdgeInfo) == 0,
"ControlEdgeInfo must be aligned with kItemAlignment");
static_assert(kItemAlignment % alignof(AllocatorAttributes) == 0,
"AllocatorAttributes must be aligned with kItemAlignment");
static_assert(sizeof(NodeItem) % alignof(EdgeInfo) == 0,
"NodeItem must be aligned with EdgeInfo");
static_assert(sizeof(NodeItem) % alignof(AllocatorAttributes) == 0,
"NodeItem must be aligned with AllocatorAttributes");
static_assert(sizeof(EdgeInfo) % alignof(AllocatorAttributes) == 0,
"EdgeInfo must be aligned with AllocatorAttributes");
const size_t bytes =
((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;
return bytes;
}
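To make the byte accounting above concrete, here is a standalone worked example (plain C++) that repeats the raw_bytes formula and the alignment round-up for a hypothetical node with 2 inputs, 1 output, 3 data output edges and 1 control output edge. The struct sizes are illustrative stand-ins, not the real sizeof values used by NodeItemBytes.

#include <cstddef>
#include <cstdio>

int main() {
  // Illustrative sizes only; the real code always uses sizeof/alignof.
  const size_t kNodeItem = 48, kEdgeInfo = 12, kControlEdgeInfo = 4,
               kAllocatorAttributes = 8, kItemAlignment = sizeof(void*);
  const int num_inputs = 2, num_outputs = 1;
  const int num_output_edges = 3, num_output_control_edges = 1;

  const size_t raw_bytes = kNodeItem                                     // fixed
                           + num_output_edges * kEdgeInfo                // output_edges
                           + num_output_control_edges * kControlEdgeInfo // control edges
                           + num_outputs * kAllocatorAttributes          // output_attr
                           + num_outputs * sizeof(int)                   // forward_from
                           + num_inputs * 1                              // input_type (uint8)
                           + num_outputs * 1;                            // output_type (uint8)

  // Round up to a multiple of kItemAlignment, exactly as NodeItemBytes does.
  const size_t bytes =
      ((raw_bytes + kItemAlignment - 1) / kItemAlignment) * kItemAlignment;

  // With these stand-in sizes this prints raw=103 rounded=104.
  std::printf("raw=%zu rounded=%zu\n", raw_bytes, bytes);
  return 0;
}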
char* GraphView::InitializeNode(char* ptr, const Node* n) {
const int id = n->id();
CHECK(node_offsets_[id] == kuint32max); // Initial value in constructor
const size_t bytes = NodeItemBytes(n);
constexpr size_t kItemAlignment = sizeof(NodeItem*);
CHECK_EQ(reinterpret_cast<uintptr_t>(ptr) % kItemAlignment, 0);
NodeItem* item = reinterpret_cast<NodeItem*>(ptr);
// We store a 32-bit offset relative to the beginning of space_, so that we
// only need an array of 32-bit values to map from node id to the NodeItem*,
// (versus 64 bits on most machines if we just stored an array of NodeItem*
// pointers). Casting to int64 is needed on 32bit CPU to avoid comparing
// values as "int" vs "size_t" in CHECK_LE.
CHECK_LE(static_cast<int64>(ptr - space_), kuint32max);
const uint32 offset = static_cast<uint32>(ptr - space_);
node_offsets_[id] = offset;
ptr += bytes;
int32 num_output_edges;
int32 num_output_control_edges;
std::tie(num_output_edges, num_output_control_edges) = CountOutputEdges(n);
const int num_inputs = n->num_inputs();
const int num_outputs = n->num_outputs();
new (item) NodeItem();
item->num_inputs = num_inputs;
item->num_outputs = num_outputs;
item->num_output_edges = num_output_edges;
item->num_output_control_edges = num_output_control_edges;
// Fill output edges.
// Keep track of the last EdgeInfo in the EdgeInfo array that references
// a given output slot. For all but the last, we need to do a copy of the
// Tensor when propagating results downstream in the graph, but for the
// last one, we can just do a move of the Tensor object to propagate it.
gtl::InlinedVector<EdgeInfo*, 4> last_indices(num_outputs, nullptr);
EdgeInfo* dst_edge = item->output_edge_base();
for (auto e : n->out_edges()) {
if (e->IsControlEdge()) continue;
dst_edge->dst_id = e->dst()->id();
CHECK_LE(e->src_output(), 0x3FFFFFFF); // Must fit in 31 bits
dst_edge->output_slot = e->src_output();
dst_edge->is_last = false;
const int output_slot = dst_edge->output_slot;
if (output_slot >= 0) {
last_indices[output_slot] = dst_edge;
}
// NOTE: The `input_slot` will be rewritten to the frame-wide offset later
// in `ExecutorImpl::Initialize()`.
dst_edge->input_slot = e->dst_input();
dst_edge++;
}
for (EdgeInfo* edge_info : last_indices) {
if (edge_info != nullptr) {
edge_info->is_last = true;
}
}
ControlEdgeInfo* dst_control_edge = item->output_control_edge_base();
for (auto e : n->out_edges()) {
if (!e->IsControlEdge() || IsSink(e->dst())) continue;
dst_control_edge->dst_id = e->dst()->id();
dst_control_edge++;
}
AllocatorAttributes* output_attrs = item->output_attr_base();
for (int i = 0; i < num_outputs; i++) {
new (&output_attrs[i]) AllocatorAttributes();
}
DCHECK_LT(DataType_MAX, 255); // Must fit in uint8
uint8* input_types = item->input_type_base();
for (int i = 0; i < num_inputs; i++) {
input_types[i] = static_cast<uint8>(n->input_type(i));
DCHECK_EQ(item->input_type(i), n->input_type(i));
}
// Check ScopedAllocatorAttrs and forward_from. Also assign output_types.
{
std::vector<int> forward_input;
Status fwd_status =
GetNodeAttr(n->attrs(), "_forward_input", &forward_input);
std::vector<int> scoped_allocator_attrs;
Status sa_status =
GetNodeAttr(n->attrs(), "_scoped_allocator", &scoped_allocator_attrs);
int* forward_from = item->forward_from_base();
uint8* output_types = item->output_type_base();
for (int i = 0; i < num_outputs; ++i) {
output_types[i] = static_cast<uint8>(n->output_type(i));
DCHECK_EQ(item->output_type(i), n->output_type(i));
forward_from[i] = OpKernelContext::Params::kNoReservation;
if (sa_status.ok()) {
for (int j = 0; j < scoped_allocator_attrs.size(); j += 2) {
if (scoped_allocator_attrs[j] == i) {
// This output slot must be explicitly allocated from a
// ScopedAllocator.
forward_from[i] = OpKernelContext::Params::kNeverForward;
DCHECK_EQ(output_attrs[i].scope_id, 0);
output_attrs[i].scope_id = scoped_allocator_attrs[j + 1];
}
}
}
if (fwd_status.ok() &&
forward_from[i] == OpKernelContext::Params::kNoReservation) {
DCHECK_EQ(forward_input.size() % 2, 0);
for (int j = 0; j < forward_input.size(); j += 2) {
if (forward_input[j + 1] == i) {
DCHECK_EQ(forward_from[i], OpKernelContext::Params::kNoReservation);
forward_from[i] = forward_input[j];
break;
}
}
}
}
}
return ptr;
}
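The is_last flag computed in InitializeNode is what later lets the executor copy a produced tensor to all but the final consumer and move it into the last one. A minimal standalone sketch of that copy-all-but-last pattern (plain C++ with std::string standing in for a tensor; not TensorFlow API):
#include <iostream>
#include <string>
#include <utility>
#include <vector>
struct Edge { bool is_last; };
// Deliver `value` along each edge: copy for every consumer except the one
// marked is_last, which receives the value by move.
void Propagate(std::string value, const std::vector<Edge>& edges,
               std::vector<std::string>* inputs) {
  for (const Edge& e : edges) {
    if (e.is_last) {
      inputs->push_back(std::move(value));  // last consumer: no copy
    } else {
      inputs->push_back(value);  // earlier consumers: copy
    }
  }
}
int main() {
  std::vector<std::string> inputs;
  Propagate("payload", {{false}, {false}, {true}}, &inputs);
  std::cout << inputs.size() << " consumers received the value\n";
  return 0;
}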
Status GraphView::Initialize(const Graph* g) {
CHECK(node_offsets_ == nullptr);
const int num_nodes = g->num_node_ids();
num_nodes_ = num_nodes;
size_t total_bytes = 0;
for (const Node* n : g->nodes()) {
if (n->out_edges().size() > kint32max) {
return errors::InvalidArgument(
"The executor cannot handle nodes with more than ", kint32max,
" output edges. Node ", n->name(), " had ", n->out_edges().size(),
" output edges.");
}
total_bytes += NodeItemBytes(n);
}
node_offsets_ = new uint32[num_nodes];
for (int i = 0; i < num_nodes; i++) {
node_offsets_[i] = kuint32max;
}
space_ = new char[total_bytes]; // NodeItem objects are allocated here
char* ptr = space_;
for (const Node* n : g->nodes()) {
ptr = InitializeNode(ptr, n);
}
CHECK_EQ(ptr, space_ + total_bytes);
return Status::OK();
}
namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset. This function
// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
// only allow one ScopedAllocator use per Node.
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
int output_index,
AllocatorAttributes* alloc_attr) {
DCHECK_LE(2, sc_attr.size());
for (int i = 0; i < sc_attr.size(); i += 2) {
if (sc_attr[i] == output_index) {
CHECK_EQ(alloc_attr->scope_id, 0);
alloc_attr->scope_id = sc_attr[i + 1];
return true;
}
}
return false;
}
} // namespace
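As a worked example of the (output_index, scope_id) pair encoding described above, a standalone sketch with made-up values (not part of this diff):
#include <cstddef>
#include <iostream>
#include <vector>
// "_scoped_allocator" stores flat (output_index, scope_id) pairs: even
// offsets hold output indices, the following odd offsets hold the ids.
int LookupScopeId(const std::vector<int>& sc_attr, int output_index) {
  for (std::size_t i = 0; i + 1 < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) return sc_attr[i + 1];
  }
  return 0;  // 0 means "no ScopedAllocator assigned".
}
int main() {
  const std::vector<int> sc_attr = {0, 17, 2, 42};  // outputs 0 and 2
  std::cout << LookupScopeId(sc_attr, 2) << "\n";  // 42
  std::cout << LookupScopeId(sc_attr, 1) << "\n";  // 0
  return 0;
}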
void GraphView::SetScopedAllocatorAttrs(
const std::vector<const Node*>& sa_nodes) {
for (const Node* sa : sa_nodes) {
NodeItem* sa_item = node(sa->id());
AllocatorAttributes* sa_attrs = sa_item->output_attr_base();
// Control edges out of the ScopedAllocator should lead to its use-instance
// nodes, but may include a few other nodes.
for (const auto& e : sa->out_edges()) {
if (IsSink(e->dst()) || !e->IsControlEdge()) {
continue;
}
Node* use_node = e->dst();
NodeItem* item = node(use_node->id());
AllocatorAttributes* use_attrs = item->output_attr_base();
std::vector<int> scoped_allocator_attrs;
Status s = GetNodeAttr(use_node->attrs(), "_scoped_allocator",
&scoped_allocator_attrs);
if (!s.ok()) {
VLOG(2) << "Failed to find expected ScopedAllocator attr on "
<< use_node->name();
continue;
}
// There can be more than one output using ScopedAllocation, but this
// analysis assumes they use the same ScopedAllocator.
for (const auto& e : use_node->out_edges()) {
if (IsSink(e->dst()) || !e->IsControlEdge()) {
AllocatorAttributes attr;
if (ExtractScopedAllocatorAttr(scoped_allocator_attrs,
e->src_output(), &attr)) {
// Set the scope_id on this use instance node.
(use_attrs + e->src_output())->Merge(attr);
// Propagate the other attributes of this node back to the SA node.
attr = *(use_attrs + e->src_output());
attr.scope_id = 0;
sa_attrs->Merge(attr);
}
}
}
}
}
}
namespace {
Status InferAllocAttr(const Node* n, const Node* dst,
const DeviceNameUtils::ParsedName& local_dev_name,
AllocatorAttributes* attr) {
Status s;
// Note that it's possible for *n to be a Recv and *dst to be a Send,
// so these two cases are not mutually exclusive.
if (IsRecv(n)) {
string src_name;
s = GetNodeAttr(n->attrs(), "send_device", &src_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_src_name;
if (!DeviceNameUtils::ParseFullName(src_name, &parsed_src_name)) {
s = errors::Internal("Bad send_device attr '", src_name, "' in node ",
n->name());
return s;
}
if (!DeviceNameUtils::IsSameAddressSpace(parsed_src_name, local_dev_name)) {
// Value is going to be the sink of an RPC.
attr->set_nic_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of an RPC in";
} else if ((local_dev_name.type == "CPU" || n->IsHostRecv()) &&
parsed_src_name.type != "CPU") {
// Value is going to be the sink of a local DMA from GPU to CPU (or
// other types of accelerators).
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the sink of a gpu->cpu copy";
} else {
VLOG(2) << "default alloc case local type " << local_dev_name.type
<< " remote type " << parsed_src_name.type;
}
}
if (IsSend(dst)) {
string dst_name;
s = GetNodeAttr(dst->attrs(), "recv_device", &dst_name);
if (!s.ok()) return s;
DeviceNameUtils::ParsedName parsed_dst_name;
if (!DeviceNameUtils::ParseFullName(dst_name, &parsed_dst_name)) {
s = errors::Internal("Bad recv_device attr '", dst_name, "' in node ",
n->name());
return s;
}
if (!DeviceNameUtils::IsSameAddressSpace(parsed_dst_name, local_dev_name)) {
// Value is going to be the source of an RPC.
attr->set_nic_compatible(true);
VLOG(2) << "node " << n->name() << " is the source of an RPC out";
} else if ((local_dev_name.type == "CPU" || dst->IsHostSend()) &&
parsed_dst_name.type != "CPU") {
// Value is going to be the source of a local DMA from CPU to GPU (or
// other types of accelerators).
// Note that this does not cover the case where the allocation of the
// output tensor is not generated by the src: n.
attr->set_gpu_compatible(true);
VLOG(2) << "node " << n->name() << " is the source of a cpu->gpu copy";
} else {
VLOG(2) << "default alloc case local type " << local_dev_name.type
<< " remote type " << parsed_dst_name.type;
}
}
if (n->IsCollective()) {
// We'll make the sweeping assumption that any collective op is going
// to be involved in network i/o.
attr->set_nic_compatible(true);
}
return s;
}
} // namespace
Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) {
Status s;
DeviceNameUtils::ParsedName local_dev_name = device->parsed_name();
std::vector<const Node*> scoped_allocator_instances;
for (const Node* n : g->nodes()) {
NodeItem* item = node(n->id());
AllocatorAttributes* attrs = item->output_attr_base();
if (IsScopedAllocator(n)) {
scoped_allocator_instances.push_back(n);
}
// Examine the out edges of each node looking for special use
// cases that may affect memory allocation attributes.
for (const auto& e : n->out_edges()) {
if (!e->IsControlEdge()) {
AllocatorAttributes attr;
s = InferAllocAttr(n, e->dst(), local_dev_name, &attr);
if (!s.ok()) return s;
if (attr.value != 0 || attr.scope_id != 0) {
attrs[e->src_output()].Merge(attr);
}
}
}
for (int out = 0; out < n->num_outputs(); out++) {
const OpKernel* op_kernel = item->kernel;
DCHECK_LT(out, op_kernel->output_memory_types().size());
bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
if (on_host) {
AllocatorAttributes h;
h.set_on_host(on_host);
attrs[out].Merge(h);
}
}
}
SetScopedAllocatorAttrs(scoped_allocator_instances);
return s;
}
} // namespace tensorflow

View File

@ -0,0 +1,240 @@
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
#include <memory>
#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class Device;
class Graph;
class Node;
class OpKernel;
class Tensor;
// Represents a single data edge in a `NodeItem`.
struct EdgeInfo {
// The node ID of the destination in the containing `GraphView`.
int dst_id;
// The index of the output that produces values on this edge.
int output_slot : 31;
// true if this is the last info for output_slot in the EdgeInfo list.
bool is_last : 1;
// The index of the input that consumes values on this edge.
int input_slot;
};
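The two bit-fields above pack a signed 31-bit slot index and a 1-bit flag into one 32-bit word, which is why InitializeNode checks that src_output() fits in 0x3FFFFFFF. A minimal standalone sketch of that packing (plain C++; the exact object size is compiler-dependent):
#include <cstdint>
#include <iostream>
// A signed 31-bit bit-field can hold non-negative values up to
// 2^30 - 1 = 0x3FFFFFFF, leaving one bit for the flag.
struct PackedEdge {
  std::int32_t output_slot : 31;
  bool is_last : 1;
};
int main() {
  PackedEdge e{0x3FFFFFFF, true};
  std::cout << sizeof(PackedEdge) << " bytes\n";  // typically 4
  std::cout << e.output_slot << " " << e.is_last << "\n";  // 1073741823 1
  return 0;
}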
// Represents a single control edge in a `NodeItem`.
struct ControlEdgeInfo {
// The node ID of the destination in the containing `GraphView`.
int dst_id;
};
// Compact structure representing a graph node and its associated kernel.
//
// Each NodeItem is an element of exactly one GraphView.
struct NodeItem {
// The index of this node's item in its GraphView.
int node_id = -1;
// Cached attributes of this node for fast lookup.
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
bool is_constant_enter : 1; // True iff IsEnter(node) and
// node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_source : 1; // True iff IsSource(node)
// True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
bool is_enter_exit_or_next_iter : 1;
bool is_transfer_node : 1; // True iff IsTransferNode(node)
bool is_initialization_op : 1; // True iff IsInitializationOp(node)
bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
bool is_next_iteration : 1; // True iff IsNextIteration(node)
bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp"
// True iff the destination of any output edge is a merge or control
// trigger node.
bool is_any_consumer_merge_or_control_trigger : 1;
// The kernel for this node.
OpKernel* kernel = nullptr;
// If the kernel is a Const op, this points to the constant tensor.
const Tensor* const_tensor = nullptr;
// Cached values of node->num_inputs() and node->num_outputs(), to
// avoid levels of indirection.
int num_inputs;
int num_outputs;
// ExecutorImpl::tensors_[input_start] is the 1st positional input
// for this node.
int input_start = 0;
// Number of output edges, excluding control edges.
int32 num_output_edges;
// Number of output control edges.
int32 num_output_control_edges;
// If non-null, contains an array of num_outputs bools, where the ith bool
// is true if and only if the ith output is consumed by another node.
std::unique_ptr<bool[]> outputs_required;
gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
num_output_edges);
}
gtl::ArraySlice<EdgeInfo> output_edges() const {
return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
}
gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
num_output_control_edges);
}
DataType input_type(int i) const {
DCHECK_LT(i, num_inputs);
return static_cast<DataType>(input_type_base()[i]);
}
DataType output_type(int i) const {
DCHECK_LT(i, num_outputs);
return static_cast<DataType>(output_type_base()[i]);
}
// Return array of per-output allocator attributes.
const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
// Return array of expected input indices from which each output should
// be forwarded:
// kNeverForward (-2) for DO NOT FORWARD (must allocate).
// kNoReservation (-1) for no expected forwarding.
// 0... for forward from that input.
const int* forward_from() const { return forward_from_base(); }
string DebugString() const;
private:
friend class GraphView;
NodeItem() {}
// Variable length section starts immediately after *this
// (uint8 is enough for DataType).
// EdgeInfo out_edges[num_output_edges];
// ControlEdgeInfo out_control_edges[num_output_control_edges];
// AllocatorAttributes output_attr[num_outputs];
// int forward_from[num_outputs];
// uint8 input_type[num_inputs];
// uint8 output_type[num_outputs];
// Return pointer to variable length section.
char* var() const {
return const_cast<char*>(reinterpret_cast<const char*>(this) +
sizeof(NodeItem));
}
EdgeInfo* output_edge_base() const {
return reinterpret_cast<EdgeInfo*>(var());
}
ControlEdgeInfo* output_control_edge_base() const {
return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
num_output_edges);
}
AllocatorAttributes* output_attr_base() const {
return reinterpret_cast<AllocatorAttributes*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges);
}
int* forward_from_base() const {
return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) *
num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs);
}
uint8* input_type_base() const {
return reinterpret_cast<uint8*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
}
uint8* output_type_base() const {
return reinterpret_cast<uint8*>(
var() + sizeof(EdgeInfo) * num_output_edges +
sizeof(ControlEdgeInfo) * num_output_control_edges +
sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
sizeof(uint8) * num_inputs);
}
TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
};
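The accessors above all follow the same "fixed header plus variable-length trailer" layout: each section starts at the running sum of the byte sizes of the sections before it. A minimal standalone sketch of that layout (plain C++, not part of this diff; two trailing arrays instead of six):
#include <cstddef>
#include <iostream>
#include <memory>
#include <new>
struct Header {
  int num_a;
  int num_b;
  // Variable-length section starts immediately after *this:
  //   int           a[num_a];
  //   unsigned char b[num_b];
  char* var() { return reinterpret_cast<char*>(this) + sizeof(Header); }
  int* a_base() { return reinterpret_cast<int*>(var()); }
  unsigned char* b_base() {
    return reinterpret_cast<unsigned char*>(var() + sizeof(int) * num_a);
  }
};
int main() {
  const int num_a = 3, num_b = 2;
  const std::size_t bytes =
      sizeof(Header) + sizeof(int) * num_a + sizeof(unsigned char) * num_b;
  std::unique_ptr<char[]> space(new char[bytes]);
  Header* h = new (space.get()) Header{num_a, num_b};
  h->a_base()[0] = 7;
  h->b_base()[1] = 255;
  std::cout << h->a_base()[0] << " "
            << static_cast<int>(h->b_base()[1]) << "\n";  // 7 255
  return 0;
}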
// Immutable view of a Graph organized for efficient execution.
//
// TODO(b/152651962): Add independent unit tests for this class.
class GraphView {
public:
GraphView() : space_(nullptr) {}
~GraphView();
Status Initialize(const Graph* g);
Status SetAllocAttrs(const Graph* g, const Device* device);
void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
NodeItem* node(int32 id) const {
DCHECK_GE(id, 0);
DCHECK_LT(id, num_nodes_);
uint32 offset = node_offsets_[id];
return ((offset == kuint32max)
? nullptr
: reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
}
int32 num_nodes() const { return num_nodes_; }
private:
char* InitializeNode(char* ptr, const Node* n);
size_t NodeItemBytes(const Node* n);
int32 num_nodes_ = 0;
uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
// node_offsets_[id] holds the byte offset for node w/ "id" in space_
char* space_; // NodeItem objects are allocated here
TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_

View File

@ -45,13 +45,13 @@ auto* graph_run_time_usecs_histogram = monitoring::Sampler<0>::New(
auto* graph_run_input_tensor_bytes = monitoring::Sampler<0>::New(
{"/tensorflow/core/graph_run_input_tensor_bytes",
"The size of input tensors in bytes."},
// Power of 2 with bucket count 14 (256G)
{monitoring::Buckets::Exponential(1, 4, 20)});
// Power of 2 with bucket count 14 (256MB)
{monitoring::Buckets::Exponential(1, 4, 14)});
auto* graph_run_output_tensor_bytes = monitoring::Sampler<0>::New(
{"/tensorflow/core/graph_run_output_tensor_bytes",
"The size of output tensors in bytes."},
// Power of 2 with bucket count 14 (256G)
// Power of 2 with bucket count 14 (256MB)
{monitoring::Buckets::Exponential(1, 4, 14)});
auto* graph_unused_outputs = monitoring::Counter<1>::New(
@ -72,8 +72,8 @@ auto* tf_data_bytes_fetched_counter = monitoring::Counter<0>::New(
auto* tf_data_getnext_duration_counter = monitoring::Sampler<0>::New(
{"/tensorflow/data/getnext_duration",
"Microseconds spent fetching an element from tf.data Dataset iterator."},
// Power of 2 with bucket count 14 (256G)
{monitoring::Buckets::Exponential(1, 4, 20)});
// Power of 2 with bucket count 10 (1024 ms)
{monitoring::Buckets::Exponential(1, 2, 10)});
auto* tf_data_elements_counter = monitoring::Counter<1>::New(
"/tensorflow/data/elements", "tf.data elements", "name");

View File

@ -110,14 +110,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
session_metadata_, this);
}
DeviceMgr const* all_devices = device_mgr_;
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
all_devices = parent_->remote_device_mgr();
}
for (auto d : all_devices->ListDevices()) {
device_set_.AddDevice(d);
}
InitializeDeviceSet();
}
/* static */
@ -214,6 +207,18 @@ Status ProcessFunctionLibraryRuntime::GetDeviceContext(
"function executions");
}
void ProcessFunctionLibraryRuntime::InitializeDeviceSet() {
DeviceMgr const* all_devices = device_mgr_;
if (parent_ != nullptr && parent_->remote_device_mgr() != nullptr) {
all_devices = parent_->remote_device_mgr();
}
device_set_.reset(new DeviceSet);
for (auto d : all_devices->ListDevices()) {
device_set_->AddDevice(d);
}
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) const {
Device* device = nullptr;
@ -225,7 +230,8 @@ FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
}
const auto& iter = flr_map_->find(device);
if (iter == flr_map_->end()) {
LOG(ERROR) << "Could not find device: " << device_name;
VLOG(1) << "Could not find device: " << device_name
<< "in the local process.";
return nullptr;
}
return iter->second.get();
@ -678,7 +684,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
TF_RETURN_IF_ERROR(
SetArgShape(options.input_resource_dtypes_and_shapes, arg_nodes));
TF_RETURN_IF_ERROR(PinArgsAndRets(
options.input_devices, options.output_devices, device_set_, arg_nodes,
options.input_devices, options.output_devices, *device_set_, arg_nodes,
ret_nodes,
options.config_proto.allow_soft_placement() ? default_device : nullptr));
@ -691,7 +697,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
bool control_rets_updated = false;
TF_RETURN_IF_ERROR(FunctionOptimizationPassRegistry::Global().Run(
device_set_, options.config_proto, &graph, &data->lib_def_,
*device_set_, options.config_proto, &graph, &data->lib_def_,
&control_ret_node_names, &control_rets_updated));
if (control_rets_updated) {
@ -714,7 +720,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
optimization_options.session_options = &session_options;
optimization_options.graph = &graph;
optimization_options.flib_def = &data->lib_def_;
optimization_options.device_set = &device_set_;
optimization_options.device_set = device_set_.get();
optimization_options.is_function_graph = true;
DumpGraph("Before running PRE_PLACEMENT passes", graph.get());
@ -725,7 +731,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
// exceptions/warnings in case where nested function call options are ignored.
DumpGraph("Before calling Placer", graph.get());
Placer placer(graph.get(), function_name, optimization_options.flib_def,
&device_set_, default_device,
device_set_.get(), default_device,
options.config_proto.allow_soft_placement(),
options.config_proto.log_device_placement());
TF_RETURN_IF_ERROR(placer.Run());
@ -741,7 +747,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
DumpGraph("Before running graph optimization fn", graph.get());
Status status = options.optimize_graph_fn(
std::move(ret_node_names), std::move(control_ret_node_names),
&data->lib_def_, device_set_, cpu_device, &graph);
&data->lib_def_, *device_set_, cpu_device, &graph);
if (!status.ok()) {
LOG(WARNING) << "Ignoring multi-device function optimization failure: "
<< status.ToString();
@ -765,7 +771,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
TF_RETURN_IF_ERROR(
PartitionFunctionGraph(device_set_, std::move(graph), &subgraphs));
PartitionFunctionGraph(*device_set_, std::move(graph), &subgraphs));
for (const auto& pair : subgraphs) {
DumpGraph(strings::StrCat("Before running POST_PARTITIONING passes (",
@ -841,7 +847,7 @@ Status ProcessFunctionLibraryRuntime::InstantiateMultiDevice(
const string& target = pair.first;
const string& device_type =
device_set_.FindDeviceByName(target)->device_type();
device_set_->FindDeviceByName(target)->device_type();
Graph* subgraph = pair.second.get();
status->Update(UpdateArgAndRetvalMetadata(
@ -1258,12 +1264,18 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::DoneCallback
ProcessFunctionLibraryRuntime::ApplyCleanUpToDoneCallback(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done,
const Rendezvous* rendezvous) const {
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
const Rendezvous* created_rendezvous) const {
return
[this, items, done = std::move(done), rendezvous](const Status& status) {
if (rendezvous) {
rendezvous->Unref();
[this, items, done = std::move(done), step_id,
created_rendezvous](const Status& status) {
if (created_rendezvous) {
DCHECK(rendezvous_factory_);
created_rendezvous->Unref();
Status s = rendezvous_factory_.CleanUp(step_id);
if (!s.ok()) {
LOG(ERROR) << s;
}
}
auto* local_status = new Status(status);
CleanUp(items, [local_status, done](const Status& cleanup_status) {
@ -1281,15 +1293,16 @@ void ProcessFunctionLibraryRuntime::Run(
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done) const {
FunctionLibraryRuntime::Options new_opts = opts;
Rendezvous* rendezvous = nullptr;
Rendezvous* created_rendezvous = nullptr;
if (!opts.rendezvous) {
if (rendezvous_factory_) {
Status s = rendezvous_factory_(opts.step_id, device_mgr_, &rendezvous);
Status s =
rendezvous_factory_(opts.step_id, device_mgr_, &created_rendezvous);
if (!s.ok()) {
done(s);
return;
}
new_opts.rendezvous = rendezvous;
new_opts.rendezvous = created_rendezvous;
} else {
done(
errors::FailedPrecondition("The caller does not provide a rendezvous "
@ -1301,7 +1314,8 @@ void ProcessFunctionLibraryRuntime::Run(
}
auto* cleanup_items = new std::vector<std::unique_ptr<CleanUpItem>>;
done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done), rendezvous);
done = ApplyCleanUpToDoneCallback(cleanup_items, std::move(done),
new_opts.step_id, created_rendezvous);
bool multi_device;
{
tf_shared_lock l(mu_);

View File

@ -71,7 +71,7 @@ class ProcessFunctionLibraryRuntime {
DistributedFunctionLibraryRuntime* parent = nullptr,
const CustomKernelCreator* custom_kernel_creator = nullptr,
const SessionMetadata* session_metadata = nullptr,
Rendezvous::Factory rendezvous_factory = nullptr);
Rendezvous::Factory rendezvous_factory = Rendezvous::Factory());
virtual ~ProcessFunctionLibraryRuntime() {
// Deleting the FunctionLibraryRuntime map will delete the function handles
@ -191,7 +191,10 @@ class ProcessFunctionLibraryRuntime {
const DeviceMgr* device_mgr() { return device_mgr_; }
const DeviceSet* device_set() { return &device_set_; }
const DeviceSet* device_set() { return device_set_.get(); }
// Initialize the set of local and remote devices for op device selection.
void InitializeDeviceSet();
const ConfigProto* config() const { return config_ ? &(*config_) : nullptr; }
@ -294,7 +297,7 @@ class ProcessFunctionLibraryRuntime {
FunctionLibraryRuntime::DoneCallback ApplyCleanUpToDoneCallback(
std::vector<std::unique_ptr<CleanUpItem>>* items,
FunctionLibraryRuntime::DoneCallback done,
FunctionLibraryRuntime::DoneCallback done, const int64 step_id,
const Rendezvous* rendezvous) const;
DistributedFunctionLibraryRuntime* const parent_;
@ -422,7 +425,7 @@ class ProcessFunctionLibraryRuntime {
Env* const env_;
const absl::optional<const ConfigProto> config_;
const DeviceMgr* const device_mgr_;
DeviceSet device_set_;
std::unique_ptr<DeviceSet> device_set_;
const FunctionLibraryDefinition* lib_def_;
thread::ThreadPool* default_thread_pool_;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include <memory>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
@ -122,10 +123,24 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts,
/*thread_pool=*/nullptr, cluster_flr_.get(),
/*custom_kernel_creator=*/nullptr, session_metadata,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}));
Rendezvous::Factory{
[this](const int64 step_id, const DeviceMgr* device_mgr,
Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
if (rendezvous_ref_counts_.find(step_id) !=
rendezvous_ref_counts_.end()) {
rendezvous_ref_counts_[step_id]++;
} else {
rendezvous_ref_counts_[step_id] = 1;
}
return Status::OK();
},
[this](const int64 step_id) {
CHECK(rendezvous_ref_counts_.find(step_id) !=
rendezvous_ref_counts_.end());
rendezvous_ref_counts_[step_id]--;
return Status::OK();
}}));
}
Status Instantiate(
@ -289,6 +304,9 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<TestClusterFLR> cluster_flr_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
// To ensure that we are cleaning up the rendezvous properly.
std::unordered_map<int64, int> rendezvous_ref_counts_;
};
TEST_F(ProcessFunctionLibraryRuntimeTest, GetFLRNull) {
@ -362,6 +380,9 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
test::ExpectTensorEqual<tstring>(
y, test::AsTensor<tstring>({"/job:a/replica:0/task:0/device:CPU:0"},
TensorShape({})));
EXPECT_EQ(1, rendezvous_ref_counts_.size());
EXPECT_EQ(opts.step_id, rendezvous_ref_counts_.begin()->first);
EXPECT_EQ(0, rendezvous_ref_counts_.begin()->second);
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {

View File

@ -61,13 +61,14 @@ Status Dataset::FromGraph(Params params, const GraphDef& graph_def,
/*thread_pool=*/nullptr, /*parent=*/nullptr,
/*custom_kernel_creator=*/nullptr,
/*session_metadata=*/nullptr,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
});
Rendezvous::Factory{
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}});
string fetch_node = "";
for (auto node : graph_def.node()) {
for (const auto& node : graph_def.node()) {
if (node.op() == "_Retval") {
fetch_node = node.input(0);
}

View File

@ -378,8 +378,8 @@ Status EagerServiceImpl::ExecuteOp(const Operation& operation,
return errors::InvalidArgument("Invalid TensorProto: ",
input.tensor().DebugString());
} else {
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(tensor), nullptr, nullptr, eager_context, &handle));
handle = TensorHandle::CreateLocalHandle(std::move(tensor), nullptr,
nullptr, eager_context);
op->AddInput(handle);
}
}
@ -558,9 +558,8 @@ Status EagerServiceImpl::SendTensor(const SendTensorOp& send_tensor,
return errors::InvalidArgument("Unable to parse tensor proto");
}
TensorHandle* tensor_handle = nullptr;
TF_RETURN_IF_ERROR(TensorHandle::CreateLocalHandle(
std::move(tensor), nullptr, nullptr, eager_context, &tensor_handle));
TensorHandle* tensor_handle = TensorHandle::CreateLocalHandle(
std::move(tensor), nullptr, nullptr, eager_context);
TensorHandle* copied_handle = nullptr;
Device* device;
TF_RETURN_IF_ERROR(eager_context->FindDeviceFromName(

View File

@ -162,8 +162,8 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
in.op_device().empty() ? in.device() : in.op_device();
TF_RETURN_IF_ERROR(
parent_->FindDeviceFromName(device_name.c_str(), &device));
TF_RETURN_IF_ERROR(TensorHandle::CreateLazyRemoteHandle(
in.op_id(), in.output_num(), in.dtype(), device, parent_, out));
*out = TensorHandle::CreateLazyRemoteHandle(in.op_id(), in.output_num(),
in.dtype(), device, parent_);
TensorHandle::ResourceHandleInfo resource_handle_info;
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;

View File

@ -70,9 +70,8 @@ TEST_F(RemoteMgrTest, SerializeLocalTensorHandleWithRemoteMirror) {
RemoteMgr remote_mgr(false, ctx_);
Tensor t(DT_FLOAT, TensorShape({0}));
TensorHandle* handle;
TF_ASSERT_OK(TensorHandle::CreateLocalHandle(std::move(t), local_device_,
local_device_, ctx_, &handle));
TensorHandle* handle = TensorHandle::CreateLocalHandle(
std::move(t), local_device_, local_device_, ctx_);
const uint64 op_id = 2;
const int output_num = 3;
TF_ASSERT_OK(handle->AddUnshapedRemoteMirror(remote_device_, op_id,
@ -91,10 +90,9 @@ TEST_F(RemoteMgrTest, SerializeRemoteTensorHandle) {
const uint64 op_id = 3;
const int output_num = 1;
TensorHandle* handle;
TF_ASSERT_OK(TensorHandle::CreateUnshapedRemoteHandle(
TensorHandle* handle = TensorHandle::CreateUnshapedRemoteHandle(
op_id, output_num,
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_, &handle));
/*remote_task=*/"", DT_FLOAT, remote_device_, ctx_);
RemoteTensorHandle remote_handle;
TF_ASSERT_OK(remote_mgr.SerializeRemoteTensorHandle(
handle, &remote_handle, remote_device_, remote_device_->name()));

View File

@ -67,7 +67,7 @@ GraphMgr::GraphMgr(const WorkerEnv* worker_env, DeviceMgr* device_mgr)
}
GraphMgr::~GraphMgr() {
for (auto p : table_) p.second->Unref();
for (const auto& p : table_) p.second->Unref();
}
GraphMgr::Item::~Item() {
@ -141,13 +141,18 @@ Status GraphMgr::InitItem(
gdef.versions().producer(), item->lib_def.get(),
graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
/*custom_kernel_creator=*/nullptr, /*session_metadata=*/nullptr,
[this, session](const int64 step_id, const DeviceMgr*,
Rendezvous** r) -> Status {
auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
TF_RETURN_IF_ERROR(remote_r->Initialize(session));
*r = remote_r;
return Status::OK();
}));
Rendezvous::Factory{
[this, session](const int64 step_id, const DeviceMgr*,
Rendezvous** r) -> Status {
auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
TF_RETURN_IF_ERROR(remote_r->Initialize(session));
*r = remote_r;
return Status::OK();
},
[this](const int64 step_id) {
this->worker_env_->rendezvous_mgr->Cleanup(step_id);
return Status::OK();
}}));
// Constructs the graph out of "gdef".
Graph graph(OpRegistry::Global());

View File

@ -387,45 +387,28 @@ void OpKernelContext::SetStatus(const Status& status) {
}
Status OpKernelContext::input(StringPiece name, const Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
if (input_is_ref(start)) {
int index;
TF_RETURN_IF_ERROR(get_input_index(name, &index));
if (input_is_ref(index)) {
return errors::InvalidArgument("OpKernel used ref input name '", name,
"' when non-ref input was expected");
}
*tensor = (*params_->inputs)[start].tensor;
*tensor = (*params_->inputs)[index].tensor;
return Status::OK();
}
Status OpKernelContext::input_dtype(StringPiece name, DataType* dtype) const {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
const TensorValue& value((*params_->inputs)[start]);
int index;
TF_RETURN_IF_ERROR(get_input_index(name, &index));
const TensorValue& value((*params_->inputs)[index]);
*dtype = value.dtype();
return Status::OK();
}
Status OpKernelContext::input_ref_mutex(StringPiece name, mutex** out_mutex) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
*out_mutex = input_ref_mutex(start);
int index;
TF_RETURN_IF_ERROR(get_input_index(name, &index));
*out_mutex = input_ref_mutex(index);
return Status::OK();
}
@ -497,23 +480,9 @@ bool OpKernelContext::forward_input_to_output_with_shape(
Status OpKernelContext::forward_input_to_output_with_shape(
StringPiece input_name, StringPiece output_name,
const TensorShape& output_shape, Tensor** output) {
int input_index, output_index, stop;
TF_RETURN_IF_ERROR(
params_->op_kernel->InputRange(input_name, &input_index, &stop));
if (stop != input_index + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
input_name,
"' when single-valued input was "
"expected");
}
TF_RETURN_IF_ERROR(
params_->op_kernel->OutputRange(output_name, &output_index, &stop));
if (stop != output_index + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
output_name,
"' when single-valued output was "
"expected");
}
int input_index, output_index;
TF_RETURN_IF_ERROR(get_input_index(input_name, &input_index));
TF_RETURN_IF_ERROR(get_output_index(output_name, &output_index));
if (!forward_input_to_output_with_shape(input_index, output_index,
output_shape, output)) {
return errors::FailedPrecondition("OpKernel could not forward input '",
@ -621,23 +590,18 @@ void OpKernelContext::delete_ref_input(int index, bool lock_held) {
Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
if (!input_is_ref(start)) {
int index;
TF_RETURN_IF_ERROR(get_input_index(name, &index));
if (!input_is_ref(index)) {
return errors::InvalidArgument("OpKernel used non-ref input name '", name,
"' when ref input was expected");
}
// return a copy of the Ref acquired while holding the mutex
if (lock_held) {
*tensor = *(*params_->inputs)[start].tensor;
*tensor = *(*params_->inputs)[index].tensor;
} else {
tf_shared_lock l(*input_ref_mutex(start));
*tensor = *(*params_->inputs)[start].tensor;
tf_shared_lock l(*input_ref_mutex(index));
*tensor = *(*params_->inputs)[index].tensor;
}
return Status::OK();
}
@ -645,18 +609,13 @@ Status OpKernelContext::mutable_input(StringPiece name, Tensor* tensor,
Status OpKernelContext::replace_ref_input(StringPiece name,
const Tensor& tensor,
bool lock_held) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was expected");
}
if (!input_is_ref(start)) {
int index;
TF_RETURN_IF_ERROR(get_input_index(name, &index));
if (!input_is_ref(index)) {
return errors::InvalidArgument("OpKernel used immutable input name '", name,
"' when ref input was expected");
}
replace_ref_input(start, tensor, lock_held);
replace_ref_input(index, tensor, lock_held);
return Status::OK();
}
@ -888,7 +847,22 @@ Status OpKernelContext::allocate_persistent(DataType type,
return s;
}
Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
Status OpKernelContext::get_input_index(StringPiece name,
int* out_index) const {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->InputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued input name '",
name,
"' when single-valued input was "
"expected");
}
*out_index = start;
return Status::OK();
}
Status OpKernelContext::get_output_index(StringPiece name,
int* out_index) const {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
@ -897,22 +871,31 @@ Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
"' when single-valued output was "
"expected");
}
set_output(start, tensor);
*out_index = start;
return Status::OK();
}
void OpKernelContext::set_output(int index, const Tensor& tensor) {
CHECK_GE(index, 0);
CHECK_LT(index, outputs_.size());
const DataType type = params_->op_kernel->output_type(index);
CHECK(!IsRefType(type));
CHECK_EQ(mutable_output(index), nullptr);
Status OpKernelContext::set_output(StringPiece name, const Tensor& tensor) {
int index;
TF_RETURN_IF_ERROR(get_output_index(name, &index));
set_output(index, tensor);
return Status::OK();
}
Status OpKernelContext::set_output(StringPiece name, Tensor&& tensor) {
int index;
TF_RETURN_IF_ERROR(get_output_index(name, &index));
set_output(index, std::move(tensor));
return Status::OK();
}
bool OpKernelContext::maybe_set_output_by_allocate_and_copy(
int index, const Tensor& tensor) {
bool allocate_and_copy = false;
const bool never_forward =
(params_->forward_from_array != nullptr &&
params_->forward_from_array[index] == Params::kNeverForward);
if (never_forward) {
if (TF_PREDICT_FALSE(never_forward)) {
maybe_initialize_scope_id_set();
if (allocated_scope_ids_->find(output_alloc_attr(index).scope_id) ==
allocated_scope_ids_->end()) {
@ -929,7 +912,7 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
}
}
if (allocate_and_copy) {
if (TF_PREDICT_FALSE(allocate_and_copy)) {
// This output was marked to not be forwarded either during graph
// construction or grappler passes. Force an allocation and copy input to
// output.
@ -939,31 +922,59 @@ void OpKernelContext::set_output(int index, const Tensor& tensor) {
<< params_->forward_from_array[index] << " alloc_attr.scope_id "
<< output_alloc_attr(index).scope_id;
auto new_tensor = MakeUnique<Tensor>();
Status s = allocate_tensor(type, tensor.shape(), new_tensor.get(),
Status s = allocate_tensor(tensor.dtype(), tensor.shape(), new_tensor.get(),
output_alloc_attr(index));
TF_CHECK_OK(s);
device()->CopyTensorInSameDevice(&tensor, new_tensor.get(),
op_device_context(), [](const Status&) {});
outputs_[index] = TensorValue(new_tensor.release());
} else {
}
return allocate_and_copy;
}
void OpKernelContext::maybe_track_allocations_for_set_output(
const Tensor& tensor) {
if (TF_PREDICT_FALSE(track_allocations()) && tensor.TotalBytes() > 0) {
DCHECK(tracking_state_);
mutex_lock l(tracking_state_->stats_mu);
const auto it = std::find_if(
tracking_state_->temp_tensor_buffer_and_size.begin(),
tracking_state_->temp_tensor_buffer_and_size.end(),
[&tensor](const std::pair<const void*, int64>& e) {
return e.first ==
static_cast<const void*>(tensor.tensor_data().data());
});
if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
tracking_state_->temp_memory_allocated -= it->second;
tracking_state_->temp_tensor_buffer_and_size.erase(it);
}
}
}
void OpKernelContext::set_output(int index, const Tensor& tensor) {
CHECK_GE(index, 0);
CHECK_LT(index, outputs_.size());
const DataType type = params_->op_kernel->output_type(index);
CHECK(!IsRefType(type));
CHECK_EQ(outputs_[index].tensor, nullptr);
if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
// Input can be forwarded to output; incref on `tensor` and set output at
// `index` to this tensor.
outputs_[index] = TensorValue(new Tensor(tensor));
if (track_allocations() && tensor.TotalBytes() > 0) {
DCHECK(tracking_state_);
mutex_lock l(tracking_state_->stats_mu);
const auto it = std::find_if(
tracking_state_->temp_tensor_buffer_and_size.begin(),
tracking_state_->temp_tensor_buffer_and_size.end(),
[&tensor](const std::pair<const void*, int64>& e) {
return e.first ==
static_cast<const void*>(tensor.tensor_data().data());
});
if (it != tracking_state_->temp_tensor_buffer_and_size.end()) {
tracking_state_->temp_memory_allocated -= it->second;
tracking_state_->temp_tensor_buffer_and_size.erase(it);
}
}
maybe_track_allocations_for_set_output(*outputs_[index].tensor);
}
}
void OpKernelContext::set_output(int index, Tensor&& tensor) {
CHECK_GE(index, 0);
CHECK_LT(index, outputs_.size());
const DataType type = params_->op_kernel->output_type(index);
CHECK(!IsRefType(type));
CHECK_EQ(outputs_[index].tensor, nullptr);
if (TF_PREDICT_TRUE(!maybe_set_output_by_allocate_and_copy(index, tensor))) {
// Input can be forwarded to output; set output at `index` to this tensor.
outputs_[index] = TensorValue(new Tensor(std::move(tensor)));
maybe_track_allocations_for_set_output(*outputs_[index].tensor);
}
}
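A tensorflow::Tensor is a reference-counted handle to its buffer, so the rvalue overload above transfers that buffer without the refcount churn of a copy. A minimal standalone sketch of the difference, using std::shared_ptr as a stand-in for the refcounted buffer (an illustration, not the Tensor implementation itself):
#include <iostream>
#include <memory>
#include <utility>
int main() {
  std::shared_ptr<int> buffer = std::make_shared<int>(42);
  std::cout << buffer.use_count() << "\n";         // 1
  std::shared_ptr<int> copied = buffer;            // copy: refcount -> 2
  std::cout << buffer.use_count() << "\n";         // 2
  std::shared_ptr<int> moved = std::move(buffer);  // move: refcount stays 2
  std::cout << moved.use_count() << "\n";          // 2
  std::cout << (buffer == nullptr) << "\n";        // 1 (source handle released)
  return 0;
}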
@ -977,28 +988,16 @@ void OpKernelContext::set_output_ref(int index, mutex* mu,
Status OpKernelContext::set_output_ref(StringPiece name, mutex* mu,
Tensor* tensor_for_ref) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
set_output_ref(start, mu, tensor_for_ref);
int index;
TF_RETURN_IF_ERROR(get_output_index(name, &index));
set_output_ref(index, mu, tensor_for_ref);
return Status::OK();
}
Status OpKernelContext::mutable_output(StringPiece name, Tensor** tensor) {
int start, stop;
TF_RETURN_IF_ERROR(params_->op_kernel->OutputRange(name, &start, &stop));
if (stop != start + 1) {
return errors::InvalidArgument("OpKernel used list-valued output name '",
name,
"' when single-valued output was "
"expected");
}
*tensor = mutable_output(start);
int index;
TF_RETURN_IF_ERROR(get_output_index(name, &index));
*tensor = mutable_output(index);
return Status::OK();
}

View File

@ -492,6 +492,7 @@ class OpOutputList {
DataType expected_output_dtype(int i) const;
Status allocate(int i, const TensorShape& shape, Tensor** output);
void set(int i, const Tensor& tensor);
void set(int i, Tensor&& tensor);
void set_ref(int i, mutex* mu, Tensor* tensor_for_ref);
int size() const { return stop_ - start_; }
Iterator begin() const { return Iterator(this, 0); }
@ -1031,6 +1032,9 @@ class OpKernelContext {
// REQUIRES: 'tensor' must have the same MemoryType as
// output_memory_types[index]. See comment above.
Status set_output(StringPiece name, const Tensor& tensor);
Status set_output(StringPiece name, Tensor&& tensor);
void set_output(int index, const Tensor& tensor);
void set_output(int index, Tensor&& tensor);
// To output a reference. Caller retains ownership of mu and tensor_for_ref,
// and they must outlive all uses within the step. See comment above.
@ -1198,7 +1202,6 @@ class OpKernelContext {
// The following functions all have versions that return Status
// to capture error conditions, and are strongly preferred.
Tensor* mutable_output(int index);
void set_output(int index, const Tensor& tensor);
mutex* input_ref_mutex(int index);
void set_output_ref(int index, mutex* mu, Tensor* tensor_for_ref);
TensorValue release_output(int index);
@ -1274,6 +1277,16 @@ class OpKernelContext {
Tensor* out_tensor, AllocatorAttributes allocator_attr,
const AllocationAttributes& allocation_attr);
// Helpers for `set_output()`.
// Returns `true` if the tensor was copied into an allocated output.
bool maybe_set_output_by_allocate_and_copy(int index, const Tensor& tensor);
void maybe_track_allocations_for_set_output(const Tensor& tensor);
Status get_input_index(StringPiece name, int* out_index) const;
Status get_output_index(StringPiece name, int* out_index) const;
// Initialize the allocated_scope_ids_ set the first time this method is
// called.
void maybe_initialize_scope_id_set();
@ -1704,6 +1717,12 @@ inline void OpOutputList::set(int i, const Tensor& tensor) {
ctx_->set_output(start_ + i, tensor);
}
inline void OpOutputList::set(int i, Tensor&& tensor) {
DCHECK_GE(i, 0);
DCHECK_LT(i, stop_ - start_);
ctx_->set_output(start_ + i, std::move(tensor));
}
inline void OpOutputList::set_ref(int i, mutex* mu, Tensor* tensor_for_ref) {
DCHECK_GE(i, 0);
DCHECK_LT(i, stop_ - start_);

View File

@ -129,8 +129,43 @@ class RendezvousInterface {
// threads with no clear owner.
class Rendezvous : public RendezvousInterface, public core::RefCounted {
public:
using Factory =
std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>;
class Factory {
public:
// Default to a factory that evaluates to false.
Factory() : valid_(false) {}
Factory(std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
create_fn,
std::function<Status(const int64)> cleanup_fn)
: valid_(true),
create_fn_(std::move(create_fn)),
cleanup_fn_(std::move(cleanup_fn)) {}
// If no cleanup fn is provided, install a no-op for backwards compatibility.
explicit Factory(
std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
create_fn)
: valid_(true),
create_fn_(std::move(create_fn)),
cleanup_fn_([](const int64 step_id) { return Status::OK(); }) {}
explicit operator bool() const { return valid_; }
Status operator()(const int64 step_id, const DeviceMgr* device_mgr,
Rendezvous** rendez) const {
return create_fn_(step_id, device_mgr, rendez);
}
Status CleanUp(const int64 step_id) const { return cleanup_fn_(step_id); }
private:
bool valid_;
std::function<Status(const int64, const DeviceMgr*, Rendezvous**)>
create_fn_;
std::function<Status(const int64)> cleanup_fn_;
};
// Constructs a rendezvous key for the tensor of "name" sent from
// "src_device" to "dst_device". The tensor is generated in the frame
// and iteration specified by "frame_iter".
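As a usage note on the Factory above: a default-constructed factory converts to false, so call sites can keep guarding with `if (rendezvous_factory_)` before creating a rendezvous, and can now also call CleanUp(step_id) once the step finishes. A minimal standalone sketch of that calling pattern (plain C++ stand-ins for Status, DeviceMgr, and Rendezvous; not TensorFlow API):
#include <functional>
#include <iostream>
using Status = int;  // 0 == OK, stand-in for tensorflow::Status
struct Rendezvous { long step_id; };
class Factory {
 public:
  Factory() = default;  // evaluates to false, like the default Factory above
  Factory(std::function<Status(long, Rendezvous**)> create_fn,
          std::function<Status(long)> cleanup_fn)
      : create_fn_(std::move(create_fn)), cleanup_fn_(std::move(cleanup_fn)) {}
  explicit operator bool() const { return static_cast<bool>(create_fn_); }
  Status operator()(long step_id, Rendezvous** r) const {
    return create_fn_(step_id, r);
  }
  Status CleanUp(long step_id) const { return cleanup_fn_(step_id); }
 private:
  std::function<Status(long, Rendezvous**)> create_fn_;
  std::function<Status(long)> cleanup_fn_;
};
int main() {
  Factory factory(
      [](long step_id, Rendezvous** r) {
        *r = new Rendezvous{step_id};
        return 0;
      },
      [](long step_id) {
        std::cout << "cleaned up step " << step_id << "\n";
        return 0;
      });
  if (factory) {                      // guard, as call sites already do
    Rendezvous* r = nullptr;
    factory(/*step_id=*/7, &r);       // create the per-step rendezvous
    delete r;
    factory.CleanUp(/*step_id=*/7);   // release per-step state afterwards
  }
  return 0;
}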

View File

@ -37,39 +37,49 @@ namespace tensorflow {
namespace {
void RecordPaddingSize(int32 padding_size) {
void RecordPaddingSize(int32 padding_size, const string& model_name) {
static tensorflow::monitoring::PercentileSamplerCell* cell =
tensorflow::monitoring::PercentileSampler<0>::New(
{"/tensorflow/serving/batching/padding_size",
"Tracks the padding size distribution on batches."},
tensorflow::monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/padding_size", "model_name",
"Tracks the padding size distribution on batches by model_name (if "
"available)."},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber)
->GetCell();
->GetCell(model_name);
cell->Add(static_cast<double>(padding_size));
}
void RecordInputBatchSize(int32 batch_size) {
void RecordInputBatchSize(int32 batch_size, const string& model_name) {
static tensorflow::monitoring::PercentileSamplerCell* cell =
tensorflow::monitoring::PercentileSampler<0>::New(
{"/tensorflow/serving/batching/input_batch_size",
"Tracks the batch size distribution on the inputs."},
tensorflow::monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/input_batch_size", "model_name",
"Tracks the batch size distribution on the inputs by model_name (if "
"available)."},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, tensorflow::monitoring::UnitOfMeasure::kNumber)
->GetCell();
->GetCell(model_name);
cell->Add(static_cast<double>(batch_size));
}
void RecordBatchDelayMs(int64 batch_delay_ms) {
void RecordBatchDelayMs(int64 batch_delay_ms, const string& model_name) {
static monitoring::PercentileSamplerCell* cell =
monitoring::PercentileSampler<0>::New(
{"/tensorflow/serving/batching/batch_delay_ms",
"Tracks the batching delay for inputs."},
monitoring::PercentileSampler<1>::New(
{"/tensorflow/serving/batching/batch_delay_ms", "model_name",
"Tracks the batching delay for inputs by model_name (if "
"available)."},
/*percentiles=*/{25.0, 50.0, 75.0, 90.0, 95.0, 99.0},
/*max_samples=*/1024, monitoring::UnitOfMeasure::kTime)
->GetCell();
->GetCell(model_name);
cell->Add(static_cast<double>(batch_delay_ms));
}
const string& GetModelName(OpKernelContext* ctx) {
static string* kModelNameUnset = new string("model_name_unset");
if (!ctx->session_metadata()) return *kModelNameUnset;
if (ctx->session_metadata()->name().empty()) return *kModelNameUnset;
return ctx->session_metadata()->name();
}
} // namespace
typedef Eigen::ThreadPoolDevice CPUDevice;
@ -303,7 +313,7 @@ class BatchResource : public ResourceBase {
}
batch_components->inputs.push_back(tensor);
}
RecordInputBatchSize(tensors[0].shape().dim_size(0));
RecordInputBatchSize(tensors[0].shape().dim_size(0), GetModelName(context));
OpInputList captured_tensors;
const auto captured_status =
context->input_list("captured_tensors", &captured_tensors);
@ -386,7 +396,7 @@ class BatchResource : public ResourceBase {
const int padded_batch_size = RoundToLowestAllowedBatchSize(batch.size());
const int padding_amount = padded_batch_size - batch.size();
RecordPaddingSize(padding_amount);
RecordPaddingSize(padding_amount, GetModelName(context));
// All tasks should have the same number of input edges.
const int num_inputs = batch.task(0).inputs.size();
@ -570,8 +580,10 @@ class BatchResource : public ResourceBase {
args.insert(args.end(), captured_inputs.begin(), captured_inputs.end());
uint64 current_time = EnvTime::NowNanos();
const string& model_name = GetModelName(last_task_context);
for (int i = 0; i < batch->num_tasks(); ++i) {
RecordBatchDelayMs((current_time - batch->task(i).start_time) * 1e-6);
RecordBatchDelayMs((current_time - batch->task(i).start_time) * 1e-6,
model_name);
}
// Releases the cleanup method here, because the callback of the function
// library runtime will handle it now.

View File

@ -405,10 +405,11 @@ Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
TF_GRAPH_DEF_VERSION, lib_def_.get(), opts, thread_pool_.get(),
/*parent=*/nullptr, /*custom_kernel_creator=*/nullptr,
/*session_metadata=*/nullptr,
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
});
Rendezvous::Factory{
[](const int64, const DeviceMgr* device_mgr, Rendezvous** r) {
*r = new IntraProcessRendezvous(device_mgr);
return Status::OK();
}});
flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
if (thread_pool_ == nullptr) {
runner_ = [](const std::function<void()>& fn) { fn(); };
@ -548,8 +549,7 @@ Status DatasetOpsTestBase::AddDatasetInput(
inputs->size(), " vs. ", input_types.size());
}
bool is_ref = IsRefType(input_types[inputs->size()]);
std::unique_ptr<Tensor> input =
absl::make_unique<Tensor>(allocator_, dtype, shape);
auto input = absl::make_unique<Tensor>(allocator_, dtype, shape);
if (is_ref) {
DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);

View File

@ -50,7 +50,7 @@ void ArgOp::Compute(OpKernelContext* ctx) {
errors::InvalidArgument("Type mismatch: actual ",
DataTypeString(val.dtype()),
" vs. expect ", DataTypeString(dtype_)));
ctx->set_output(0, val);
ctx->set_output(0, std::move(val));
}
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@ -279,7 +279,7 @@ class SymbolicGradientOp : public AsyncOpKernel {
" tensor(s), but get ", rets->size(), " tensor(s) instead."));
} else {
for (size_t i = 0; i < rets->size(); ++i) {
ctx->set_output(i, (*rets)[i]);
ctx->set_output(i, std::move((*rets)[i]));
}
}
delete rets;
@ -413,7 +413,7 @@ void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
ctx->SetStatus(status);
} else {
for (size_t i = 0; i < rets->size(); ++i) {
ctx->set_output(i, (*rets)[i]);
ctx->set_output(i, std::move((*rets)[i]));
}
}
delete rets;

View File

@ -100,6 +100,12 @@ struct InTopKFunctor<GPUDevice, T, TargetT> {
errors::InvalidArgument(
"Number of targets * number of classes must be less than INT_MAX"));
if (num_targets == 0 || num_classes == 0) {
// Result is empty, so shortcut the rest of the function to avoid
// launching kernels with empty input.
return;
}
// Temporary storage for a mask computed by `ComputePredictionMaskKernel`.
Tensor predictions_mask;
OP_REQUIRES_OK(

View File

@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_properties.h"
#ifdef GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/common_runtime/gpu/gpu_managed_allocator.h"
@ -137,11 +138,16 @@ Status OpsTestBase::InitOp() {
}
Status OpsTestBase::InitOpWithGraphVersion(int graph_def_version) {
Status status;
kernel_ = CreateOpKernel(device_type_, device_, allocator(), node_def_,
graph_def_version, &status);
if (kernel_ != nullptr) input_types_ = kernel_->input_types();
return status;
std::shared_ptr<const NodeProperties> props;
TF_RETURN_IF_ERROR(NodeProperties::CreateFromNodeDef(
node_def_, OpRegistry::Global(), &props));
OpKernel* kernel;
TF_RETURN_IF_ERROR(CreateOpKernel(
device_type_, device_, allocator(), /*flib=*/nullptr,
device_->resource_manager(), props, graph_def_version, &kernel));
kernel_.reset(kernel);
input_types_ = kernel_->input_types();
return Status::OK();
}
Status OpsTestBase::RunOpKernel() {

View File

@ -224,10 +224,8 @@ REGISTER_CPU(complex128)
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -362,10 +362,8 @@ class DenseToCSRSparseMatrixGPUOp : public AsyncOpKernel {
REGISTER_GPU(GPU, float)
REGISTER_GPU(GPU, double)
#if GOOGLE_CUDA
REGISTER_GPU(GPU, complex64)
REGISTER_GPU(GPU, complex128)
#endif
namespace functor {

View File

@ -538,8 +538,13 @@ class CSRMatMulGPUOp : public CSRMatMulOp<GPUDevice, T> {
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, c_shape, &c_t));
const GPUDevice& d = ctx->eigen_device<GPUDevice>();
if (b_outer_dim == 1) {
bool use_matrix_vector_multiply = (b_outer_dim == 1);
#if TENSORFLOW_USE_ROCM
// ROCm hipsparse does not implement csrmv with transposed input a
use_matrix_vector_multiply =
use_matrix_vector_multiply && !this->transpose_a_;
#endif
if (use_matrix_vector_multiply) {
// Call matrix-vector multiply if b is a vector.
TTypes<int64>::ConstVec a_dense_shape_comp(a_dense_shape.data() + row_dim,
2);

View File

@ -107,10 +107,8 @@ class CSRMulOp : public OpKernel {
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU

View File

@ -120,10 +120,8 @@ REGISTER(CPU, complex128)
REGISTER(GPU, float)
REGISTER(GPU, double)
#if GOOGLE_CUDA
REGISTER(GPU, complex64)
REGISTER(GPU, complex128)
#endif
#undef REGISTER
@ -141,10 +139,8 @@ namespace functor {
DECLARE_GPU_SPEC(int32);
DECLARE_GPU_SPEC(float);
DECLARE_GPU_SPEC(double);
#if GOOGLE_CUDA
DECLARE_GPU_SPEC(complex64);
DECLARE_GPU_SPEC(complex128);
#endif
#undef DECLARE_GPU_SPEC
} // namespace functor

View File

@ -328,10 +328,8 @@ extern template struct COOSparseMatrixToCSRSparseMatrix<GPUDevice>;
REGISTER_GPU(float)
REGISTER_GPU(double)
#if GOOGLE_CUDA
REGISTER_GPU(complex64)
REGISTER_GPU(complex128)
#endif
#undef REGISTER_GPU

View File

@ -50,15 +50,11 @@ class LegacyVar : public ResourceBase {
VariableOp::VariableOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
dtype_ = RemoveRefType(context->output_type(0));
OP_REQUIRES_OK(context, cinfo_.Init(context->resource_manager(), def(),
true /* use name() */));
}
void VariableOp::Compute(OpKernelContext* ctx) {
mutex_lock l(init_mu_);
if (!initialized_) {
OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def(),
true /* use name() */));
initialized_ = true;
}
auto creator = [this](LegacyVar** var) {
*var = new LegacyVar(dtype_);
(*var)->tensor()->set_shape(shape_);

View File

@ -36,10 +36,7 @@ class VariableOp : public OpKernel {
private:
DataType dtype_;
TensorShape shape_;
mutex init_mu_;
ContainerInfo cinfo_ TF_GUARDED_BY(init_mu_);
bool initialized_ TF_GUARDED_BY(init_mu_){false};
ContainerInfo cinfo_;
TF_DISALLOW_COPY_AND_ASSIGN(VariableOp);
};

View File

@ -36,8 +36,22 @@ struct ProfilerOptions {
// DeviceType::kTpu: only CPU/TPU will be profiled.
DeviceType device_type = DeviceType::kUnspecified;
// Inexpensive ops are not traced by default.
int host_tracer_level = 2;
// Levels of host tracing:
// - Level 0 is used to disable host traces.
// - Level 1 enables tracing of only user-instrumented (or default) TraceMe.
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
// level program execution details (expensive TF ops, XLA ops, etc).
// This is the default.
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
// (low-level) program execution details (cheap TF ops, etc).
uint32 host_tracer_level = 2;
// Levels of device tracing:
// - Level 0 is used to disable device traces.
// - Level 1 is used to enable device traces.
// - More levels might be defined for specific device for controlling the
// verbosity of the trace.
uint32 device_tracer_level = 1;
// Whether to enable the python function call tracer.
bool enable_python_tracer = false;
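For context on how these levels are consumed: trace points are typically annotated with a level, and an annotation is recorded only when host tracing is enabled and its level does not exceed the configured host_tracer_level. A minimal standalone sketch of that gating (an illustrative assumption about the check, not the profiler's actual code):
#include <cstdint>
#include <iostream>
// Record an annotation only if host tracing is on and the annotation is not
// more verbose than the configured level.
bool ShouldTrace(std::uint32_t host_tracer_level,
                 std::uint32_t annotation_level) {
  return host_tracer_level != 0 && annotation_level <= host_tracer_level;
}
int main() {
  std::cout << ShouldTrace(2, 1)   // 1: default level records level-1 TraceMes
            << ShouldTrace(2, 3)   // 0: level-3 (verbose) annotations dropped
            << ShouldTrace(0, 1)   // 0: level 0 disables host tracing
            << "\n";
  return 0;
}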

View File

@ -14,11 +14,37 @@ service ProfilerService {
}
message ProfileOptions {
// Some default option values are not the proto3 default values. Use this
// version field to determine whether to use the default option values
// instead of the proto3 defaults.
uint32 version = 5;
// We don't collect the dataset ops by default for better trace-viewer
// scalability. The caller can manually set this field to include the ops.
bool include_dataset_ops = 1;
// next-field: 2
// Levels of host tracing: (version >= 1)
// - Level 0 is used to disable host traces.
// - Level 1 enables tracing of only user instrumented (or default) TraceMe.
// - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high
// level program execution details (expensive TF ops, XLA ops, etc).
// This is the default.
// - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose
// (low-level) program execution details (cheap TF ops, etc).
uint32 host_tracer_level = 2;
// Levels of device tracing: (version >= 1)
// - Level 0 is used to disable device traces.
// - Level 1 is used to enable device traces.
// - More levels might be defined for specific devices to control the
// verbosity of the trace.
uint32 device_tracer_level = 3;
// Whether to enable Python function call tracing. Runtime overhead ensues if
// enabled. Default off. (version >= 1)
uint32 python_tracer_level = 4;
// next-field: 6
}
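A hedged client-side sketch of filling in the new fields; the generated setters are the standard proto3 ones, the header path is an assumption, and the values are illustrative:
#include "tensorflow/core/profiler/profiler_service.pb.h"  // assumed generated header
tensorflow::ProfileOptions MakeProfileOptions() {
  tensorflow::ProfileOptions opts;
  // A non-zero version tells the server to honor the explicit levels below
  // instead of substituting its own defaults (see GetOptions in the profiler
  // service implementation later in this diff).
  opts.set_version(1);
  opts.set_include_dataset_ops(false);
  opts.set_host_tracer_level(2);    // user TraceMe + expensive ops (default)
  opts.set_device_tracer_level(1);  // enable device traces
  opts.set_python_tracer_level(0);  // Python tracer off
  return opts;
}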
message ToolRequestOptions {

View File

@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/env_time.h"
#include "tensorflow/core/profiler/convert/xplane_to_profile_response.h"
#include "tensorflow/core/profiler/internal/profiler_interface.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/util/ptr_util.h"
@ -51,7 +52,8 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
::grpc::Status Profile(::grpc::ServerContext* ctx, const ProfileRequest* req,
ProfileResponse* response) override {
VLOG(1) << "Received a profile request: " << req->DebugString();
std::unique_ptr<ProfilerSession> profiler = ProfilerSession::Create();
std::unique_ptr<ProfilerSession> profiler =
ProfilerSession::Create(GetOptions(req->opts()));
Status status = profiler->Status();
if (!status.ok()) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
@ -74,6 +76,19 @@ class ProfilerServiceImpl : public grpc::ProfilerService::Service {
return ::grpc::Status::OK;
}
private:
profiler::ProfilerOptions GetOptions(const tensorflow::ProfileOptions& opts) {
profiler::ProfilerOptions options;
if (opts.version()) {
options.host_tracer_level = opts.host_tracer_level();
options.device_tracer_level = opts.device_tracer_level();
options.enable_python_tracer = opts.python_tracer_level() > 0;
} else {
// Use the default option values.
}
return options;
}
};
} // namespace

View File

@ -122,6 +122,8 @@ const StatTypeMap& GetStatTypeMap() {
{"fragmentation", kFragmentation},
{"peak_bytes_in_use", kPeakBytesInUse},
{"requested_bytes", kRequestedBytes},
{"allocation_bytes", kAllocationBytes},
{"addr", kAddress},
{"shape", kTensorShapes},
// Device trace arguments.
{"device_id", kDeviceId},

View File

@ -113,6 +113,8 @@ enum StatType {
kFragmentation,
kPeakBytesInUse,
kRequestedBytes,
kAllocationBytes,
kAddress,
kTensorShapes,
// Device trace arguments.
kDeviceId,

View File

@ -12022,7 +12022,7 @@ func SampleDistortedBoundingBoxV2Seed2(value int64) SampleDistortedBoundingBoxV2
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
// If not specified, defaults to {f:0.75 f:1.33}
// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@ -12033,7 +12033,7 @@ func SampleDistortedBoundingBoxV2AspectRatioRange(value []float32) SampleDistort
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to {f:0.05 f:1}
// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxV2AreaRange(value []float32) SampleDistortedBoundingBoxV2Attr {
return func(m optionalAttr) {
m["area_range"] = value
@ -12251,7 +12251,7 @@ func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBo
//
// value: The cropped area of the image must have an aspect ratio =
// width / height within this range.
// If not specified, defaults to {f:0.75 f:1.33}
// If not specified, defaults to {f:0.75 f:1.33}
func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["aspect_ratio_range"] = value
@ -12262,7 +12262,7 @@ func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistorted
//
// value: The cropped area of the image must contain a fraction of the
// supplied image within this range.
// If not specified, defaults to {f:0.05 f:1}
// If not specified, defaults to {f:0.05 f:1}
func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
return func(m optionalAttr) {
m["area_range"] = value
@ -19038,7 +19038,7 @@ func ImageSummaryMaxImages(value int64) ImageSummaryAttr {
// ImageSummaryBadColor sets the optional bad_color attribute to value.
//
// value: Color to use for pixels with non-finite values.
// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
// If not specified, defaults to {dtype:DT_UINT8 tensor_shape:{dim:{size:4}} int_val:255 int_val:0 int_val:0 int_val:255}
func ImageSummaryBadColor(value tf.Tensor) ImageSummaryAttr {
return func(m optionalAttr) {
m["bad_color"] = value
@ -20109,7 +20109,7 @@ func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterV2Dilations(value []int64) Conv3DBackpropFilterV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@ -21281,7 +21281,7 @@ func Conv2DBackpropInputDataFormat(value string) Conv2DBackpropInputAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropInputDilations(value []int64) Conv2DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -21989,7 +21989,7 @@ func Conv2DDataFormat(value string) Conv2DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DDilations(value []int64) Conv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22185,7 +22185,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeOutType(value tf.DataTy
// QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAndRequantizeAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22254,7 +22254,7 @@ func QuantizedDepthwiseConv2DWithBiasAndReluOutType(value tf.DataType) Quantized
// QuantizedDepthwiseConv2DWithBiasAndReluDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasAndReluDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAndReluAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22369,7 +22369,7 @@ func QuantizedDepthwiseConv2DWithBiasOutType(value tf.DataType) QuantizedDepthwi
// QuantizedDepthwiseConv2DWithBiasDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DWithBiasDilations(value []int64) QuantizedDepthwiseConv2DWithBiasAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22428,7 +22428,7 @@ func QuantizedDepthwiseConv2DOutType(value tf.DataType) QuantizedDepthwiseConv2D
// QuantizedDepthwiseConv2DDilations sets the optional dilations attribute to value.
//
// value: List of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedDepthwiseConv2DDilations(value []int64) QuantizedDepthwiseConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22602,7 +22602,7 @@ func QuantizedConv2DPerChannelOutType(value tf.DataType) QuantizedConv2DPerChann
// QuantizedConv2DPerChannelDilations sets the optional dilations attribute to value.
//
// value: list of dilation values.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DPerChannelDilations(value []int64) QuantizedConv2DPerChannelAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -22979,7 +22979,7 @@ func Conv3DBackpropInputV2DataFormat(value string) Conv3DBackpropInputV2Attr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputV2Dilations(value []int64) Conv3DBackpropInputV2Attr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25322,7 +25322,7 @@ func AvgPool3DGrad(scope *Scope, orig_input_shape tf.Output, grad tf.Output, ksi
type Conv3DBackpropFilterAttr func(optionalAttr)
// Conv3DBackpropFilterDilations sets the optional dilations attribute to value.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropFilterDilations(value []int64) Conv3DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25385,7 +25385,7 @@ func Conv3DDataFormat(value string) Conv3DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DDilations(value []int64) Conv3DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -25636,7 +25636,7 @@ func DepthwiseConv2dNativeBackpropInputDataFormat(value string) DepthwiseConv2dN
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropInputDilations(value []int64) DepthwiseConv2dNativeBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -26120,7 +26120,7 @@ func QuantizedConv2DOutType(value tf.DataType) QuantizedConv2DAttr {
// filter element on that dimension. The dimension order is determined by the
// value of `data_format`, see above for details. Dilations in the batch and
// depth dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func QuantizedConv2DDilations(value []int64) QuantizedConv2DAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -40326,7 +40326,7 @@ func DepthwiseConv2dNativeBackpropFilterDataFormat(value string) DepthwiseConv2d
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeBackpropFilterDilations(value []int64) DepthwiseConv2dNativeBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -45852,7 +45852,7 @@ func Conv2DBackpropFilterDataFormat(value string) Conv2DBackpropFilterAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func Conv2DBackpropFilterDilations(value []int64) Conv2DBackpropFilterAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -46704,7 +46704,7 @@ func LoadTPUEmbeddingFTRLParameters(scope *Scope, parameters tf.Output, accumula
type Conv3DBackpropInputAttr func(optionalAttr)
// Conv3DBackpropInputDilations sets the optional dilations attribute to value.
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1 i:1}
func Conv3DBackpropInputDilations(value []int64) Conv3DBackpropInputAttr {
return func(m optionalAttr) {
m["dilations"] = value
@ -46775,7 +46775,7 @@ func DepthwiseConv2dNativeDataFormat(value string) DepthwiseConv2dNativeAttr {
// element on that dimension. The dimension order is determined by the value of
// `data_format`, see above for details. Dilations in the batch and depth
// dimensions must be 1.
// If not specified, defaults to {i:1 i:1 i:1 i:1}
// If not specified, defaults to {i:1 i:1 i:1 i:1}
func DepthwiseConv2dNativeDilations(value []int64) DepthwiseConv2dNativeAttr {
return func(m optionalAttr) {
m["dilations"] = value

View File

@ -274,9 +274,9 @@ class OpNode {
return tensorflow::errors::Internal(
"Cannot read from invalid tensor index ", input_index);
}
tensorflow::TensorHandle* handle;
TF_RETURN_IF_ERROR(tensorflow::TensorHandle::CreateLocalHandle(
buffer_map->GetTensor(input_index), &handle));
tensorflow::TensorHandle* handle =
tensorflow::TensorHandle::CreateLocalHandle(
buffer_map->GetTensor(input_index));
op_->MutableInputs()->push_back(handle);
} else {
// If this is a forwardable tensor, we will remove it from the previous

View File

@ -237,6 +237,7 @@ cc_library(
"//tensorflow/lite/delegates/gpu/common:model_transformer",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/gl:api2",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:span",
],
)

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <thread> // NOLINT(build/c++11)
#include <vector>
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/delegates/gpu/api.h"
@ -70,6 +71,28 @@ class Delegate {
options_ = options ? *options : TfLiteGpuDelegateOptionsV2Default();
}
TfLiteDelegate* tflite_delegate() { return &delegate_; }
const TfLiteGpuDelegateOptionsV2& options() const { return options_; }
private:
TfLiteDelegate delegate_ = {
.data_ = reinterpret_cast<void*>(this),
.Prepare = DelegatePrepare,
.CopyFromBufferHandle = nullptr,
.CopyToBufferHandle = nullptr,
.FreeBufferHandle = nullptr,
.flags = kTfLiteDelegateFlagsNone,
};
TfLiteGpuDelegateOptionsV2 options_;
};
// Represent the execution of a subset of nodes on GPU.
class DelegateKernel {
public:
explicit DelegateKernel(const TfLiteGpuDelegateOptionsV2& options)
: options_(options) {}
absl::Status Prepare(TfLiteContext* context,
const TfLiteDelegateParams* delegate_params) {
thread_id_prepare_ = std::this_thread::get_id();
@ -133,20 +156,6 @@ class Delegate {
return builder->Build(&runner_);
}
absl::Status SetInputsAndOutputs(TfLiteContext* context) {
int i = 0;
for (auto index : input_indices_) {
RETURN_IF_ERROR(
runner_->SetInputObject(i++, GetTensorObject(index, context)));
}
i = 0;
for (auto index : output_indices_) {
RETURN_IF_ERROR(
runner_->SetOutputObject(i++, GetTensorObject(index, context)));
}
return absl::OkStatus();
}
absl::Status Invoke(TfLiteContext* context) {
if (thread_id_prepare_ != std::this_thread::get_id()) {
TFLITE_LOG(tflite::TFLITE_LOG_WARNING,
@ -162,6 +171,19 @@ class Delegate {
return runner_->Run();
}
private:
absl::Status SetInputsAndOutputs(TfLiteContext* context) {
for (int i = 0; i < input_indices_.size(); ++i) {
RETURN_IF_ERROR(runner_->SetInputObject(
i, GetTensorObject(input_indices_[i], context)));
}
for (int i = 0; i < output_indices_.size(); ++i) {
RETURN_IF_ERROR(runner_->SetOutputObject(
i, GetTensorObject(output_indices_[i], context)));
}
return absl::OkStatus();
}
ObjectDef GetObjectDef(int index) const {
ObjectDef default_object_def;
default_object_def.data_type = DataType::FLOAT32;
@ -176,9 +198,6 @@ class Delegate {
return MakeCpuMemory(absl::MakeSpan(tensor.data.raw, tensor.bytes));
}
TfLiteDelegate* tflite_delegate() { return &delegate_; }
private:
absl::Status InitializeOpenClApi(GraphFloat32* graph,
std::unique_ptr<InferenceBuilder>* builder,
bool* graph_is_destroyed) {
@ -230,28 +249,20 @@ class Delegate {
return absl::OkStatus();
}
TfLiteDelegate delegate_ = {
reinterpret_cast<void*>(this), // .data_
DelegatePrepare, // .Prepare
nullptr, // .CopyFromBufferHandle
nullptr, // .CopyToBufferHandle
nullptr, // .FreeBufferHandle
kTfLiteDelegateFlagsNone, // .flags
};
TfLiteGpuDelegateOptionsV2 options_;
// Shared across all DelegateKernel instances, passed by the Delegate
// instance.
const TfLiteGpuDelegateOptionsV2& options_;
std::unique_ptr<cl::InferenceEnvironment> cl_environment_;
std::unique_ptr<gl::InferenceEnvironment> gl_environment_;
std::unique_ptr<InferenceRunner> runner_;
std::vector<int64_t> input_indices_;
std::vector<int64_t> output_indices_;
std::thread::id thread_id_prepare_; // thread id used for Prepare()
bool enforce_same_thread_ = false; // flag to enforce same thread for Invoke
};
inline Delegate* GetDelegate(TfLiteNode* node) {
return reinterpret_cast<Delegate*>(node->user_data);
inline DelegateKernel* GetDelegateKernel(TfLiteNode* node) {
return reinterpret_cast<DelegateKernel*>(node->user_data);
}
inline Delegate* GetDelegate(TfLiteDelegate* delegate) {
@ -267,16 +278,20 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
auto* gpu_delegate = GetDelegate(params->delegate);
// Everything below should happen in prepare function call, but TFLite
// for whatever reason forbids that.
const auto status = gpu_delegate->Prepare(context, params);
auto gpu_delegate_kernel =
absl::make_unique<DelegateKernel>(gpu_delegate->options());
const auto status = gpu_delegate_kernel->Prepare(context, params);
if (!status.ok()) {
context->ReportError(context, "TfLiteGpuDelegate Init: %s",
std::string(status.message()).c_str());
return nullptr;
}
return gpu_delegate;
return gpu_delegate_kernel.release();
},
// .free
[](TfLiteContext*, void* buffer) -> void {},
[](TfLiteContext*, void* buffer) -> void {
delete reinterpret_cast<DelegateKernel*>(buffer);
},
// .prepare
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
if (!node->user_data) {
@ -292,7 +307,7 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
},
// .invoke
[](TfLiteContext* context, TfLiteNode* node) -> TfLiteStatus {
const auto status = GetDelegate(node)->Invoke(context);
const auto status = GetDelegateKernel(node)->Invoke(context);
if (!status.ok()) {
context->ReportError(context, "TfLiteGpuDelegate Invoke: %s",
std::string(status.message()).c_str());
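A hedged usage sketch of the public entry point whose internals are refactored above into a long-lived Delegate plus per-partition DelegateKernel objects; TfLiteGpuDelegateOptionsV2Default appears in this file, while TfLiteGpuDelegateV2Create/Delete and Interpreter::ModifyGraphWithDelegate are assumed from the existing TFLite GPU delegate API:
#include "tensorflow/lite/delegates/gpu/delegate.h"
#include "tensorflow/lite/interpreter.h"
TfLiteStatus ApplyGpuDelegate(tflite::Interpreter* interpreter) {
  TfLiteGpuDelegateOptionsV2 options = TfLiteGpuDelegateOptionsV2Default();
  TfLiteDelegate* delegate = TfLiteGpuDelegateV2Create(&options);
  // The .init callback above creates a DelegateKernel per delegated
  // partition and stores it in node->user_data; .free deletes it.
  const TfLiteStatus status = interpreter->ModifyGraphWithDelegate(delegate);
  // The delegate must outlive the interpreter; call
  // TfLiteGpuDelegateV2Delete(delegate) only after the interpreter is gone.
  return status;
}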

View File

@ -119,7 +119,7 @@ class Softmax : public NodeShader {
if (z < $depth$) {
highp vec4 src = $input_data_0[0, 0, z]$;
highp vec4 temp = exp(src) * sum;
$output_data_0[0, 0, z]$ = temp;
$output_data_0[0, 0, z] = temp$;
offset += 32;
}
s++;

View File

@ -173,29 +173,6 @@ objc_library(
],
)
objc_library(
name = "environment_test_lib",
testonly = 1,
srcs = ["environment_test.mm"],
sdk_frameworks = ["XCTest"],
deps = [
":environment",
"//tensorflow/lite/delegates/gpu/metal/kernels:test_util",
],
)
ios_unit_test(
name = "environment_test",
testonly = 1,
minimum_os_version = "10.0",
runner = tflite_ios_lab_runner("IOS_LATEST"),
tags = tf_gpu_tests_tags() + [
"notap",
"tflite_not_portable_android",
],
deps = [":environment_test_lib"],
)
objc_library(
name = "inference_context",
srcs = ["inference_context.mm"],
@ -273,7 +250,6 @@ objc_library(
srcs = [
"//tensorflow/lite/delegates/gpu/metal:common_test.mm",
"//tensorflow/lite/delegates/gpu/metal:compiled_model_test.mm",
"//tensorflow/lite/delegates/gpu/metal:environment_test.mm",
"//tensorflow/lite/delegates/gpu/metal:inference_context_test.mm",
],
hdrs = [

View File

@ -88,12 +88,14 @@ std::vector<ComputeTaskDescriptorPtr> SelectDepthWiseConv(
std::vector<ComputeTaskDescriptorPtr> SelectConvolutionTransposed(
int id, ValueId input_id, ValueId output_id,
const ConvolutionTransposedAttributes& attr,
const ConvolutionTransposedAttributes& attr, const DeviceInfo& device_info,
const metal::RuntimeOptions& options) {
if (CheckConvolutionTransposed4x4Support(attr)) {
return ConvolutionTransposed4x4(id, input_id, output_id, attr, options);
return ConvolutionTransposed4x4(id, input_id, output_id, attr, device_info,
options);
} else {
return ConvolutionTransposed(id, input_id, output_id, attr, options);
return ConvolutionTransposed(id, input_id, output_id, attr, device_info,
options);
}
}
@ -165,6 +167,7 @@ bool IsSuitableForWinograd4x4To6x6(const Convolution2DAttributes& attr,
absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
const std::vector<ValueId>& inputs,
const std::vector<ValueId>& outputs,
const DeviceInfo& device_info,
const RuntimeOptions& options,
int* last_node_id, int* last_value_id,
std::vector<ComputeTaskDescriptorPtr>* tasks) {
@ -219,8 +222,9 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
BHWC conv_shape{dst_shape.b, 36, tiles_x * tiles_y, dst_shape.c};
(*last_node_id) += 1;
auto t1 = ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
conv_shape, attr, options);
auto t1 =
ConvolutionWino4x4To6x6(*last_node_id, value_id, value_id + 1,
conv_shape, attr, device_info, options);
tasks->insert(tasks->end(), t1.begin(), t1.end());
Winograd36To4x4Attributes wino_down_attr;
@ -233,7 +237,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
(*last_value_id) += 2;
} else {
*tasks = ConvolutionGeneric(node_id, inputs[0], outputs[0], dst_shape,
attr, options);
attr, device_info, options);
}
break;
}
@ -242,7 +246,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
node_id, inputs[0], outputs[0],
absl::any_cast<ConvolutionTransposedAttributes>(
node->operation.attributes),
options);
device_info, options);
break;
case OperationType::DEPTHWISE_CONVOLUTION:
*tasks =
@ -255,7 +259,7 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
*tasks = FullyConnected(
node_id, inputs[0], outputs[0],
absl::any_cast<FullyConnectedAttributes>(node->operation.attributes),
options);
device_info, options);
break;
case OperationType::MAX_UNPOOLING_2D:
*tasks = MaxUnpooling(
@ -388,7 +392,8 @@ absl::Status RegisterPrimaryOps(const GraphFloat32& graph, const Node* node,
} // namespace
absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
const RuntimeOptions& options,
CompiledModel* compiled_model) {
int last_node_id = 0;
for (const auto& node : graph.nodes()) {
@ -412,7 +417,7 @@ absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
RegisterCustomOps(graph, node, inputs, outputs, options, &tasks);
if (!custom_status.ok()) {
auto primary_status =
RegisterPrimaryOps(graph, node, inputs, outputs, options,
RegisterPrimaryOps(graph, node, inputs, outputs, device_info, options,
&last_node_id, &last_value_id, &tasks);
if (!primary_status.ok()) {
return absl::UnimplementedError(

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
@ -26,7 +27,8 @@ namespace gpu {
namespace metal {
// Builds CompiledModel out of GraphFloat32 graph using the provided DeviceInfo and RuntimeOptions.
absl::Status Compile(const GraphFloat32& graph, const RuntimeOptions& options,
absl::Status Compile(const GraphFloat32& graph, const DeviceInfo& device_info,
const RuntimeOptions& options,
CompiledModel* compiled_model);
} // namespace metal
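A hedged sketch of calling the new Compile signature declared above; the api.h header name and the default-constructed RuntimeOptions are assumptions, while DeviceInfo and CompiledModel come from the headers included in this file:
#include <string>
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/metal/api.h"  // assumed header for Compile
#include "tensorflow/lite/delegates/gpu/metal/compiled_model.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
absl::Status BuildCompiledModel(const tflite::gpu::GraphFloat32& graph,
                                const std::string& metal_device_name,
                                tflite::gpu::metal::CompiledModel* compiled_model) {
  // Device-specific tuning (e.g. convolution block sizes) now happens at
  // compile time via DeviceInfo.
  tflite::gpu::metal::DeviceInfo device_info(metal_device_name);
  tflite::gpu::metal::RuntimeOptions options;  // default precision assumed
  return tflite::gpu::metal::Compile(graph, device_info, options, compiled_model);
}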

View File

@ -16,21 +16,53 @@ limitations under the License.
#ifndef TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_
#define TENSORFLOW_LITE_DELEGATES_GPU_METAL_ENVIRONMENT_H_
#include <string>
namespace tflite {
namespace gpu {
namespace metal {
enum class GpuType {
enum class AppleGPU {
kUnknown,
kA7, // iPhone 5s, iPad Air, iPad Mini 2, iPad Mini 3.
kA8, // A8 iPhone 6, A8X iPad Air 2, iPad Mini 4.
kA9, // A9 iPhone 6s, iPad (2017), A9X iPad Pro (1st generation).
kA10, // iPhone 7, iPad (2018), A10X iPad Pro (2nd generation).
kA11, // iPhone 8/X.
kA12, // iPhone Xs.
kA7,
kA8,
kA8X,
kA9,
kA9X,
kA10,
kA10X,
kA11,
kA12,
kA12X,
kA12Z,
kA13,
};
GpuType GetGpuType();
struct AppleGPUInfo {
AppleGPUInfo() = default;
explicit AppleGPUInfo(const std::string& device_name);
AppleGPU gpu_type;
bool IsLocalMemoryPreferredOverGlobal() const;
bool IsBionic() const;
// floating point rounding mode
bool IsRoundToNearestSupported() const;
int GetComputeUnitsCount() const;
};
struct DeviceInfo {
DeviceInfo() = default;
explicit DeviceInfo(const std::string& device_name);
AppleGPUInfo apple_info;
// floating point rounding mode
bool IsRoundToNearestSupported() const;
int GetComputeUnitsCount() const;
};
} // namespace metal
} // namespace gpu

View File

@ -15,82 +15,93 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#import <Metal/Metal.h>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/lite/delegates/gpu/metal/common.h"
#include <map>
#include <string>
namespace tflite {
namespace gpu {
namespace metal {
GpuType GetGpuType() {
int max_feature_set = 0;
#if defined(__IPHONE_9_0) && __IPHONE_OS_VERSION_MIN_REQUIRED >= __IPHONE_9_0
std::vector<std::pair<MTLFeatureSet, int>> features;
if (@available(iOS 8.0, *)) {
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v1, 7);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v1, 8);
}
if (@available(iOS 9.0, *)) {
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v2, 7);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v2, 8);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v1, 9);
}
if (@available(iOS 10.0, *)) {
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v3, 7);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v3, 8);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v2, 9);
}
if (@available(iOS 11.0, *)) {
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v4, 8);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v3, 9);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v1, 11);
}
if (@available(iOS 12.0, *)) {
features.emplace_back(MTLFeatureSet_iOS_GPUFamily1_v5, 7);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily2_v5, 8);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily3_v4, 9);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily4_v2, 11);
features.emplace_back(MTLFeatureSet_iOS_GPUFamily5_v1, 12);
}
id<MTLDevice> device = GetBestSupportedMetalDevice();
for (auto &type : features) {
if ([device supportsFeatureSet:type.first]) {
max_feature_set = std::max(max_feature_set, type.second);
}
}
#elif defined(__MAC_10_5) && __MAC_OS_X_VERSION_MIN_REQUIRED >= __MAC_10_5
std::vector<std::pair<MTLFeatureSet, int>> features;
if (@available(macOS 10.15, *)) {
features.emplace_back(MTLFeatureSet_macOS_GPUFamily2_v1, 12);
}
id<MTLDevice> device = GetBestSupportedMetalDevice();
for (auto &type : features) {
if ([device supportsFeatureSet:type.first]) {
max_feature_set = std::max(max_feature_set, type.second);
}
}
#endif
switch (max_feature_set) {
case 7:
return GpuType::kA7;
case 8:
return GpuType::kA8;
case 9:
return GpuType::kA9;
case 10:
return GpuType::kA10;
case 11:
return GpuType::kA11;
case 12:
return GpuType::kA12;
default:
return GpuType::kUnknown;
AppleGPUInfo::AppleGPUInfo(const std::string& device_name) {
const std::map<std::string, AppleGPU> kMapping = {
{"Apple A7 GPU", AppleGPU::kA7},
{"Apple A8 GPU", AppleGPU::kA8},
{"Apple A8X GPU", AppleGPU::kA8X},
{"Apple A9 GPU", AppleGPU::kA9},
{"Apple A9X GPU", AppleGPU::kA9X},
{"Apple A10 GPU", AppleGPU::kA10},
{"Apple A10X GPU", AppleGPU::kA10X},
{"Apple A11 GPU", AppleGPU::kA11},
{"Apple A12 GPU", AppleGPU::kA12},
{"Apple A12X GPU", AppleGPU::kA12X},
{"Apple A12Z GPU", AppleGPU::kA12Z},
{"Apple A13 GPU", AppleGPU::kA13},
};
auto it = kMapping.find(device_name);
if (it != kMapping.end()) {
gpu_type = it->second;
} else {
gpu_type = AppleGPU::kUnknown;
}
}
bool AppleGPUInfo::IsLocalMemoryPreferredOverGlobal() const {
return gpu_type == AppleGPU::kA7 ||
gpu_type == AppleGPU::kA8 ||
gpu_type == AppleGPU::kA8X;
}
bool AppleGPUInfo::IsBionic() const {
return gpu_type == AppleGPU::kA11 ||
gpu_type == AppleGPU::kA12 ||
gpu_type == AppleGPU::kA12X ||
gpu_type == AppleGPU::kA12Z ||
gpu_type == AppleGPU::kA13;
}
bool AppleGPUInfo::IsRoundToNearestSupported() const {
return IsBionic();
}
int AppleGPUInfo::GetComputeUnitsCount() const {
switch (gpu_type) {
case AppleGPU::kA7:
return 4;
case AppleGPU::kA8:
return 4;
case AppleGPU::kA8X:
return 8;
case AppleGPU::kA9:
return 6;
case AppleGPU::kA9X:
return 12;
case AppleGPU::kA10:
return 6;
case AppleGPU::kA10X:
return 12;
case AppleGPU::kA11:
return 3;
case AppleGPU::kA12:
return 4;
case AppleGPU::kA12X:
return 7;
case AppleGPU::kA12Z:
return 8;
case AppleGPU::kA13:
return 4;
case AppleGPU::kUnknown:
return 1;
}
}
DeviceInfo::DeviceInfo(const std::string& device_name) : apple_info(device_name) {}
bool DeviceInfo::IsRoundToNearestSupported() const {
return apple_info.IsRoundToNearestSupported();
}
int DeviceInfo::GetComputeUnitsCount() const {
return apple_info.GetComputeUnitsCount();
}
} // namespace metal
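A hedged sketch exercising the new DeviceInfo API; the device-name argument stands in for the string normally read from the MTLDevice, and the expected results follow the mapping and tables above:
#include <string>
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
void QueryAppleGpuCapabilities(const std::string& device_name /* e.g. "Apple A12 GPU" */) {
  tflite::gpu::metal::DeviceInfo device_info(device_name);
  // For "Apple A12 GPU": Bionic, so round-to-nearest is supported; 4 compute
  // units; local memory is not preferred over global.
  const bool round_to_nearest = device_info.IsRoundToNearestSupported();
  const int compute_units = device_info.GetComputeUnitsCount();
  const bool prefers_local_mem =
      device_info.apple_info.IsLocalMemoryPreferredOverGlobal();
  (void)round_to_nearest;
  (void)compute_units;
  (void)prefers_local_mem;
}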

View File

@ -1,49 +0,0 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#import <XCTest/XCTest.h>
#include "tensorflow/lite/delegates/gpu/metal/common.h"
using ::tflite::gpu::metal::GetGpuType;
@interface EnvironmentTest : XCTestCase
@end
@implementation EnvironmentTest
- (void)testCompileTimeOSDetection {
#if IOS_VERSION > 0
XCTAssertTrue(MACOS_VERSION == 0 && TVOS_VERSION == 0, @"IOS_VERSION: %d", int{IOS_VERSION});
#endif
#if MACOS_VERSION > 0
XCTAssertTrue(IOS_VERSION == 0 && TVOS_VERSION == 0, @"MACOS_VERSION: %d", int{MACOS_VERSION});
#endif
#if TVOS_VERSION > 0
XCTAssertTrue(IOS_VERSION == 0 && MACOS_VERSION == 0, @"TVOS_VERSION: %d", int{TVOS_VERSION});
#endif
}
- (void)testGetGpuType {
#if (IOS_VERSION > 0) || (TVOS_VERSION > 0)
auto gpuType = GetGpuType();
XCTAssertTrue(gpuType != GpuType::kUnknown);
#endif
}
@end

View File

@ -833,6 +833,7 @@ objc_library(
"//tensorflow/lite/delegates/gpu/metal:api",
"//tensorflow/lite/delegates/gpu/metal:common",
"//tensorflow/lite/delegates/gpu/metal:compiled_model",
"//tensorflow/lite/delegates/gpu/metal:environment",
"//tensorflow/lite/delegates/gpu/metal:inference_context",
"//tensorflow/lite/delegates/gpu/metal:runtime_options",
"@FP16",

View File

@ -659,32 +659,19 @@ bool IsKernelYIs1(const Convolution2DAttributes& attr) {
attr.padding.appended.h == 0;
}
int GetMaximumPossibleWavesCount(const BHWC& dst_shape, GpuType gpu) {
if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
int GetMaximumPossibleWavesCount(const AppleGPUInfo& apple_info,
const BHWC& dst_shape) {
if (apple_info.IsLocalMemoryPreferredOverGlobal()) {
return GetGroupsCountForLinearWH(dst_shape, {32, 1, 1}, {1, 1, 1});
} else {
return GetGroupsCountForLinearWHS(dst_shape, {32, 1, 1}, {1, 1, 1});
}
}
int GetCountOfComputeUnits(GpuType gpu) {
if (gpu == GpuType::kA7 || gpu == GpuType::kA8) {
return 4;
} else if (gpu == GpuType::kA9 || gpu == GpuType::kA10) {
return 6;
} else if (gpu == GpuType::kA11) {
return 3;
} else if (gpu == GpuType::kA12) {
return 4;
} else {
// unknown gpu
return 4;
}
}
int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
const int max_waves = GetMaximumPossibleWavesCount(dst_shape, gpu);
const int cu_count = GetCountOfComputeUnits(gpu);
int GetRecommendedBlockSize(const AppleGPUInfo& apple_info,
const BHWC& dst_shape) {
const int max_waves = GetMaximumPossibleWavesCount(apple_info, dst_shape);
const int cu_count = apple_info.GetComputeUnitsCount();
if (max_waves >= cu_count * 64) {
return 8;
} else if (max_waves >= cu_count * 32) {
@ -696,8 +683,9 @@ int GetRecommendedBlockSize(const BHWC& dst_shape, GpuType gpu) {
}
}
ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
const BHWC& dst_shape, GpuType gpu) {
ConvParams GetConvParamsForA7A8(const AppleGPUInfo& apple_info,
const Convolution2DAttributes& attr,
const BHWC& dst_shape) {
const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
@ -711,7 +699,7 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
params.linear_whs = false;
params.work_group_launch_order = int3(0, 1, 2);
int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
if (blk_total_size >= 4 && (dst_slices % 4 == 0 || dst_slices >= 16)) {
params.block_size.z = 4;
@ -771,14 +759,14 @@ ConvParams GetConvParamsForA7A8(const Convolution2DAttributes& attr,
return params;
}
ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
const BHWC& dst_shape, GpuType gpu) {
ConvParams GetConvParamsForA9AndHigher(const AppleGPUInfo& apple_info,
const Convolution2DAttributes& attr,
const BHWC& dst_shape) {
const int dst_slices = IntegralDivideRoundUp(dst_shape.c, 4);
const int src_slices = IntegralDivideRoundUp(attr.weights.shape.i, 4);
int blk_total_size = GetRecommendedBlockSize(dst_shape, gpu);
bool apple_gpu = gpu == GpuType::kA11 || gpu == GpuType::kA12;
int blk_total_size = GetRecommendedBlockSize(apple_info, dst_shape);
int3 block_size = int3(1, 1, 1);
if (blk_total_size >= 2 && apple_gpu) {
if (blk_total_size >= 2 && apple_info.IsBionic()) {
if (dst_shape.h % 2 != 0 && dst_shape.w % 2 == 0) {
block_size.x = 2;
} else {
@ -816,7 +804,7 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
params.work_group_size = int3(32, 1, 1);
params.work_group_launch_order = int3(0, 1, 2);
}
float precise_threshold = gpu == GpuType::kA12 ? 1.0f : 1.04f;
float precise_threshold = apple_info.IsBionic() ? 1.0f : 1.04f;
float precise_ratio = static_cast<float>(g2) / static_cast<float>(g3);
if (precise_ratio > precise_threshold) {
params.linear_wh = false;
@ -852,13 +840,13 @@ ConvParams GetConvParamsForA9AndHigher(const Convolution2DAttributes& attr,
return params;
}
ConvParams GetConvParams(const Convolution2DAttributes& attr,
ConvParams GetConvParams(const DeviceInfo& device_info,
const Convolution2DAttributes& attr,
const BHWC& dst_shape) {
auto gpu_type = GetGpuType();
if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
return GetConvParamsForA7A8(attr, dst_shape, gpu_type);
if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
return GetConvParamsForA7A8(device_info.apple_info, attr, dst_shape);
} else {
return GetConvParamsForA9AndHigher(attr, dst_shape, gpu_type);
return GetConvParamsForA9AndHigher(device_info.apple_info, attr, dst_shape);
}
}
@ -898,8 +886,9 @@ std::pair<uint3, uint3> GetDispatchSizes(const ConvParams& params,
std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const metal::RuntimeOptions& options) {
ConvParams params = GetConvParams(attr, dst_shape);
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
const metal::RuntimeOptions& options) {
ConvParams params = GetConvParams(device_info, attr, dst_shape);
auto desc = std::make_shared<ComputeTaskDescriptor>();
desc->id = id;
@ -953,7 +942,8 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const RuntimeOptions& options) {
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
const RuntimeOptions& options) {
const int dst_slices = IntegralDivideRoundUp(attr.weights.shape.o, 4);
ConvParams params;
params.work_group_launch_order = int3(2, 0, 1);
@ -965,8 +955,7 @@ std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
params.different_weights_for_height = true;
params.x_kernel_is_1 = true;
params.y_kernel_is_1 = true;
auto gpu_type = GetGpuType();
if (gpu_type == GpuType::kA7 || gpu_type == GpuType::kA8) {
if (device_info.apple_info.IsLocalMemoryPreferredOverGlobal()) {
params.weights_upload_type = WeightsUploadType::LOCAL_MEM_BY_THREADS;
params.work_group_size = int3(32, 1, 1);
params.block_size = int3(4, 1, 4);

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/metal/compute_task_descriptor.h"
#include "tensorflow/lite/delegates/gpu/metal/environment.h"
#include "tensorflow/lite/delegates/gpu/metal/runtime_options.h"
namespace tflite {
@ -29,11 +30,13 @@ namespace metal {
std::vector<ComputeTaskDescriptorPtr> ConvolutionGeneric(
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const RuntimeOptions& options);
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
const RuntimeOptions& options);
std::vector<ComputeTaskDescriptorPtr> ConvolutionWino4x4To6x6(
int id, ValueId input_id, ValueId output_id, const BHWC& dst_shape,
const Convolution2DAttributes& attr, const RuntimeOptions& options);
const Convolution2DAttributes& attr, const DeviceInfo& device_info,
const RuntimeOptions& options);
} // namespace metal
} // namespace gpu

Some files were not shown because too many files have changed in this diff.