Remove two unused LOG(FATAL) methods from TpuNodeContext.

PiperOrigin-RevId: 322711148
Change-Id: Id928e49c8227e4cbf0982443ed08dd46f767eaea
Wenhao Jia 2020-07-22 20:23:13 -07:00 committed by TensorFlower Gardener
parent 663cd759ad
commit f942f4a240
6 changed files with 69 additions and 101 deletions
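
Besides dropping the unimplemented tensor_core_location() and index_on_host() accessors, the change routes call sites through node_context->backend() instead of TpuNodeContext's own static wrappers (transfer_manager(), memory_allocator(), BorrowStream()). A minimal sketch of the call-site shift, recombined from the hunks below (node_context and device_ordinal are the locals already present in TPUExecuteOp::DoWork):

// Before: per-resource convenience wrappers on TpuNodeContext.
//   xla::TransferManager* transfer_manager = node_context->transfer_manager();
//   se::DeviceMemoryAllocator* allocator = node_context->memory_allocator();
//   TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
//                       node_context->BorrowStream(device_ordinal));

// After: fetch the shared xla::Backend once and take everything from it.
xla::Backend* const backend = node_context->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                    backend->BorrowStream(device_ordinal));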

View File

@@ -198,8 +198,8 @@ struct InputBuffers {
// Builds an InputBuffers object that describes the inputs to the computation.
xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
OpKernelContext* context, const xla::Shape& input_host_shape,
const VariableUpdateMap& variable_updates, TpuNodeContext* node_context,
se::Stream* stream) {
const VariableUpdateMap& variable_updates, xla::Backend* backend,
int device_ordinal, se::Stream* stream) {
profiler::TraceMe trace_me("BuildComputationInputs", /*level=*/2);
OpInputList arg_list;
TF_RETURN_IF_ERROR(context->input_list("args", &arg_list));
@@ -274,10 +274,8 @@ xla::StatusOr<std::unique_ptr<InputBuffers>> BuildComputationInputs(
validate_shape(variables[i].index(), *variables[i].var()->tensor()));
}
se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
xla::TransferManager* const transfer_manager =
node_context->transfer_manager();
const int device_ordinal = node_context->device_ordinal();
se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
auto input_buffers = absl::make_unique<InputBuffers>(
transfer_manager->HostShapeToDeviceShape(input_host_shape));
@@ -411,7 +409,7 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
}
xla::TransferManager* const transfer_manager =
node_context->transfer_manager();
node_context->backend()->transfer_manager();
std::vector<TensorShape> output_tensor_shapes;
output_tensor_shapes.reserve(sub_elements);
@@ -434,7 +432,8 @@ xla::StatusOr<std::unique_ptr<OutputBuffers>> AllocateOutputTensors(
TF_RET_CHECK(scoped_buffers.on_host_shape().IsTuple());
TF_RET_CHECK(!xla::ShapeUtil::IsNestedTuple(scoped_buffers.on_host_shape()));
se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
se::DeviceMemoryAllocator* const allocator =
node_context->backend()->memory_allocator();
auto output_buffers =
absl::make_unique<OutputBuffers>(std::move(scoped_buffers), allocator);
@@ -633,10 +632,11 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
TpuNodeContext::Create(device_ordinal));
profiler::TraceMe trace_me(
[&, device_ordinal] {
return absl::StrCat("TpuExecuteOp#device_ordinal=", device_ordinal,
",id=", context->step_id(),
",iter_num=", context->frame_iter().iter_id, "#");
[device_ordinal, context] {
return profiler::TraceMeEncode(
"TpuExecuteOp", {{"device_ordinal", device_ordinal},
{"id", context->step_id()},
{"iter_num", context->frame_iter().iter_id}});
},
/*level=*/2);
profiler::TraceMe trace_me_init("TPUExecuteOp::Init", /*level=*/2);
@@ -649,9 +649,9 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
// Shapes of the inputs and outputs, in xla::Shape form.
const TPUExecutableInfoProto* proto = entry->get().get_executable_info();
xla::TransferManager* const transfer_manager =
node_context->transfer_manager();
CHECK(context->op_device_context());
xla::Backend* const backend = node_context->backend();
xla::TransferManager* const transfer_manager = backend->transfer_manager();
TF_RET_CHECK(context->op_device_context());
se::Stream* stream = context->op_device_context()->stream();
TF_RET_CHECK(proto->input_shapes_size() == 1);
@@ -666,8 +666,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
proto->output_tensor_shapes().size()));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<InputBuffers> input_buffers,
BuildComputationInputs(context, host_shape, variable_update_map,
node_context.get(), stream));
BuildComputationInputs(context, host_shape, variable_update_map, backend,
device_ordinal, stream));
// Ideally this should be the host-to-device stream from XlaDeviceContext.
// The particular anti-dependency this is avoiding (why we need a separate
@@ -680,11 +680,11 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
// TODO(jmolloy): Add the necessary plumbing to obtain the proper
// host-to-device stream here.
TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
node_context->BorrowStream(device_ordinal));
backend->BorrowStream(device_ordinal));
se::DeviceMemoryAllocator* const allocator = node_context->memory_allocator();
auto shaped_buffer =
input_buffers->ToShapedBuffer(host_shape, allocator, device_ordinal);
se::DeviceMemoryAllocator* const allocator = backend->memory_allocator();
auto shaped_buffer = input_buffers->ToShapedBuffer(std::move(host_shape),
allocator, device_ordinal);
if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
shaped_buffer)) {
TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
@@ -733,8 +733,8 @@ Status TPUExecuteOp::DoWork(OpKernelContext* context) {
<< shaped_buffer.ToString();
std::vector<xla::ExecutionInput> input;
input.emplace_back(
xla::ExecutionInput(std::move(input_buffers->buffers), host_shape));
input.emplace_back(xla::ExecutionInput(std::move(input_buffers->buffers),
shaped_buffer.on_host_shape()));
// The buffers to be freed are in the `output` and will be automatically
// freed when it goes out of the scope. In async mode, this means the buffers

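The trace annotation in DoWork above also moves from an absl::StrCat-built string to profiler::TraceMeEncode, which records the device ordinal, step id, and iteration as structured key/value metadata. A self-contained sketch of that pattern (the function and variable names here are illustrative, not part of the commit):

#include <cstdint>

#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/profiler/lib/traceme_encode.h"

// Sketch: the name-generator lambda runs only when tracing is enabled, so the
// string-encoding cost is not paid on untraced steps.
void RunTracedStep(int device_ordinal, int64_t step_id) {
  tensorflow::profiler::TraceMe trace_me(
      [device_ordinal, step_id] {
        return tensorflow::profiler::TraceMeEncode(
            "TpuExecuteOp",
            {{"device_ordinal", device_ordinal}, {"id", step_id}});
      },
      /*level=*/2);
  // ... the work being traced ...
}
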
View File

@@ -62,12 +62,12 @@ static bool tpu_cancellation_terminates_process = false;
static bool tpu_cancellation_closes_chips = true;
// Host-side runtime for transfers between TPU and host.
// TODO(b/161940519): Implement this class.
class HostTransferManager {
public:
using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;
explicit HostTransferManager(TpuNodeContext*, xla::Backend*) {}
explicit HostTransferManager(TpuNodeContext* node_context)
: node_context_(node_context) {}
using HostCommmandHandler = xla::TpuExecutable::HostCommandHandler;
// Returns a function to be called when the TPU triggers a host command
// interrupt while executing the current program.
@@ -76,8 +76,6 @@ class HostTransferManager {
const std::string& rendezvous_key_base, OpKernelContext* ctx);
private:
TpuNodeContext* node_context_; // not owned
TF_DISALLOW_COPY_AND_ASSIGN(HostTransferManager);
};
@@ -417,7 +415,9 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
profiler::TraceMe traceme("TPUExecute", 2);
TF_RET_CHECK(tpu::TpuPlatformInterface::GetRegisteredPlatform() != nullptr);
TF_RET_CHECK(tpu_program != nullptr);
VLOG(1) << "TPUExecute on device " << node_context->tensor_core_location();
VLOG(1) << "TPUExecute on device " << node_context->device_ordinal();
xla::Backend* backend = node_context->backend();
XlaDevice* device =
tensorflow::down_cast<XlaDevice*>(ctx->device()->UnderlyingDevice());
@@ -425,19 +425,19 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
// Create a HostTransferManager to handle Send/Recv operations from the TPU.
std::shared_ptr<HostTransferManager> host_transfer_manager =
std::make_shared<HostTransferManager>(node_context);
std::make_shared<HostTransferManager>(node_context, backend);
TF_ASSIGN_OR_RETURN(HostTransferManager::HostCommmandHandler handler,
host_transfer_manager->Initialize(
host_transfers, rendezvous_key_base, ctx));
VLOG(2) << "Cloud TPU: Executing computation on device "
<< node_context->index_on_host();
<< node_context->device_ordinal();
xla::ExecutableRunOptions run_options;
run_options.set_stream(stream);
run_options.set_device_assignment(device_assignment);
run_options.set_rng_seed(rng_seed);
run_options.set_allocator(node_context->memory_allocator());
run_options.set_allocator(backend->memory_allocator());
run_options.set_host_to_device_stream(host_to_device_stream);
const xla::ServiceExecutableRunOptions service_run_options(run_options);
@@ -460,7 +460,7 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
TF_ASSIGN_OR_RETURN(
module->input_output_alias_config(),
xla::HloInputOutputAliasConfig::CreateFromProto(
node_context->transfer_manager()->HostShapeToDeviceShape(
backend->transfer_manager()->HostShapeToDeviceShape(
module->config().entry_computation_layout().result_shape()),
hlo_metadata.hlo_module().input_output_alias()));
TF_RET_CHECK(executable.input_shapes().size() == arguments.size());
@@ -471,11 +471,11 @@ xla::StatusOr<xla::ExecutionOutput> TPUExecute(
xla::ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
}
TF_RETURN_IF_ERROR(UpdateDynamicInputs(
stream, node_context->memory_allocator(), &arguments, input_shapes));
TF_RETURN_IF_ERROR(UpdateDynamicInputs(stream, backend->memory_allocator(),
&arguments, input_shapes));
auto tpu_executable = absl::make_unique<xla::TpuExecutable>(
tpu_program, std::move(module), handler);
tpu_program, std::move(module), /*host_command_handler=*/handler);
const int32 device_ordinal = node_context->device_ordinal();
CancellationToken token;

View File

@@ -182,13 +182,12 @@ cc_library(
":tpu_executor_c_api_hdrs",
":tpu_node_context_c_api_hdrs",
":tpu_platform_interface",
":tpu_transfer_manager_base",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/service:stream_pool",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/tpu:tpu_api",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor/lib",

View File

@@ -12,13 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
@@ -36,6 +33,8 @@ StatusOr<std::unique_ptr<TpuNodeContext>> TpuNodeContext::Create(
tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(device_ordinal,
status.c_status);
if (!status.status().ok()) {
// TpuNodeContext_CreateFn allocates a new XLA_TpuNodeContext regardless of
// status. It needs to be freed if it's not given to a TpuNodeContext below.
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context);
return status.status();
}
@@ -46,13 +45,6 @@ TpuNodeContext::~TpuNodeContext() {
tpu::NodeContextApiFn()->TpuNodeContext_FreeFn(node_context_);
}
/* static */
Status TpuNodeContext::Initialize(int device_ordinal) {
StatusHelper status;
TpuNodeContext_Initialize(device_ordinal, status.c_status);
return status.status();
}
/* static */
Status TpuNodeContext::StopChipHeartbeats() {
StatusHelper status;
@@ -68,21 +60,20 @@ Status TpuNodeContext::CloseTpuHost() {
}
/* static */
tensorflow::tpu::TpuPlatformInterface* TpuNodeContext::platform() {
Status TpuNodeContext::Initialize(int device_ordinal) {
StatusHelper status;
TpuNodeContext_Initialize(device_ordinal, status.c_status);
return status.status();
}
/* static */
TpuPlatformInterface* TpuNodeContext::platform() {
return TpuPlatformInterface::GetRegisteredPlatform();
}
/* static */
stream_executor::DeviceMemoryAllocator* TpuNodeContext::memory_allocator() {
static stream_executor::StreamExecutorMemoryAllocator* memory_allocator =
new stream_executor::StreamExecutorMemoryAllocator(
platform(),
xla::PlatformUtil::GetStreamExecutors(platform()).ValueOrDie());
return memory_allocator;
}
int TpuNodeContext::device_ordinal() const { return device_ordinal_; }
/* static */
xla::Backend* TpuNodeContext::backend() {
xla::Backend* TpuNodeContext::backend() const {
static xla::Backend* backend =
xla::Backend::CreateBackend(
xla::BackendOptions().set_platform(platform()))
@@ -91,21 +82,8 @@ xla::Backend* TpuNodeContext::backend() {
return backend;
}
/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
int device_ordinal) {
return backend()->BorrowStream(device_ordinal);
}
/* static */
StatusOr<xla::StreamPool::Ptr> TpuNodeContext::BorrowStream(
stream_executor::StreamExecutor* executor) {
return backend()->BorrowStream(executor);
}
/* static */
xla::TransferManager* TpuNodeContext::transfer_manager() {
return xla::TransferManager::GetForPlatform(platform()).ValueOrDie();
stream_executor::StreamExecutor* TpuNodeContext::stream_executor() const {
return backend()->stream_executor(device_ordinal_).ValueOrDie();
}
} // namespace tpu

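In the rewritten TpuNodeContext implementation above, memory_allocator() and backend() both hold their result in a function-local static, so the allocator and the xla::Backend are built on first use and then shared by every TpuNodeContext in the process. A stripped-down sketch of that initialize-once idiom, with placeholder types standing in for the XLA classes:

// Sketch of the lazy, process-wide singleton idiom used by memory_allocator()
// and backend() above. A function-local static is initialized exactly once,
// even under concurrent first calls, and is deliberately leaked so it outlives
// every user. The Fake* types below are placeholders, not XLA classes.
struct FakePlatform {};

struct FakeBackend {
  explicit FakeBackend(FakePlatform* platform) : platform_(platform) {}
  FakePlatform* platform_;
};

FakePlatform* GetPlatform() {
  static FakePlatform* const platform = new FakePlatform();
  return platform;
}

FakeBackend* GetBackend() {
  // Built on the first call; later calls return the same pointer.
  static FakeBackend* const backend = new FakeBackend(GetPlatform());
  return backend;
}
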
View File

@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/stream_pool.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
@@ -33,6 +34,11 @@ limitations under the License.
namespace tensorflow {
namespace tpu {
// A TpuNodeContext object represents a specific TPU node (core). The static
// class methods represent host-wide actions.
//
// First call Initialize in a freshly reset system. Then call Create to talk to
// individual nodes.
class TpuNodeContext final {
public:
using Status = stream_executor::port::Status;
@@ -47,41 +53,25 @@ class TpuNodeContext final {
}
~TpuNodeContext();
TpuNodeContext(const TpuNodeContext&) = delete;
TpuNodeContext& operator=(const TpuNodeContext&) = delete;
static Status Initialize(int device_ordinal);
static Status StopChipHeartbeats();
static Status CloseTpuHost();
static tensorflow::tpu::TpuPlatformInterface* platform();
static Status Initialize(int device_ordinal);
static stream_executor::DeviceMemoryAllocator* memory_allocator();
static TpuPlatformInterface* platform();
static xla::TransferManager* transfer_manager();
int device_ordinal() const;
static xla::Backend* backend();
xla::Backend* backend() const;
static StatusOr<xla::StreamPool::Ptr> BorrowStream(int device_ordinal);
static StatusOr<xla::StreamPool::Ptr> BorrowStream(
stream_executor::StreamExecutor* executor);
stream_executor::StreamExecutor* stream_executor() {
LOG(FATAL) << "Not implemented yet.";
}
std::string tensor_core_location() { LOG(FATAL) << "Not implemented yet."; }
int index_on_host() { LOG(FATAL) << "Not implemented yet."; }
int device_ordinal() const { return device_ordinal_; }
stream_executor::StreamExecutor* stream_executor() const;
private:
const int device_ordinal_;
XLA_TpuNodeContext* const node_context_;
TF_DISALLOW_COPY_AND_ASSIGN(TpuNodeContext);
};
} // namespace tpu

View File

@@ -26,19 +26,20 @@ XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
SE_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);
void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
void TpuNodeContext_StopChipHeartbeats(SE_Status* status);
void TpuNodeContext_CloseTpuHost(SE_Status* status);
void TpuNodeContext_Initialize(int device_ordinal, SE_Status* status);
} // extern "C"
struct TfTpu_NodeContextApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
};
#endif // TENSORFLOW_STREAM_EXECUTOR_TPU_TPU_NODE_CONTEXT_C_API_H_
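
For reference, TfTpu_NodeContextApiFn is a table of C function pointers that gets populated when the TPU library is initialized; call sites such as tpu::NodeContextApiFn()->TpuNodeContext_CreateFn(...) earlier in the diff dispatch through this table. A hand-written sketch of the same pattern, with hypothetical Widget_* entry points and the function-pointer members spelled out instead of using the real TFTPU_ADD_FN_IN_STRUCT macro:

// Hypothetical C API surface; in the real header these are the
// TpuNodeContext_* functions declared above.
extern "C" {
void Widget_Frob(int id);
void Widget_Release(int id);
}

// One function-pointer member per entry point, which is the role
// TFTPU_ADD_FN_IN_STRUCT plays for TfTpu_NodeContextApiFn.
struct WidgetApiFn {
  decltype(Widget_Frob)* Widget_FrobFn;
  decltype(Widget_Release)* Widget_ReleaseFn;
};

// Usage mirrors the call sites above: api->Widget_FrobFn(3) dispatches through
// whichever implementation was loaded into the table at startup.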