diff --git a/tensorflow/compiler/xrt/kernels/BUILD b/tensorflow/compiler/xrt/kernels/BUILD
index 662a2db636b..0e73ba7ecf3 100644
--- a/tensorflow/compiler/xrt/kernels/BUILD
+++ b/tensorflow/compiler/xrt/kernels/BUILD
@@ -48,11 +48,14 @@ cc_library(
     name = "xrt_tpu_ops",
     srcs = [
         "tpu_compile_ops.cc",
+        "tpu_execute_op.cc",
     ],
     visibility = [":friends"],
     deps = [
+        "//tensorflow/compiler/jit:xla_device_no_jit_rewrite_registration",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/xla:debug_options_flags",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -60,7 +63,9 @@ cc_library(
         "//tensorflow/compiler/xla/client:compile_only_client",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:dump",
+        "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_proto_cc",
         "//tensorflow/compiler/xrt:xrt_proto_cc",
         "//tensorflow/compiler/xrt:xrt_utils",
@@ -68,9 +73,11 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler/lib:traceme",
         "//tensorflow/core/tpu:tpu_api",
        "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
+        "//tensorflow/core/tpu:tpu_execute",
         "//tensorflow/core/tpu/kernels:tpu_compilation_cache_entry",
         "//tensorflow/core/tpu/kernels:tpu_compilation_cache_interface",
         "//tensorflow/core/tpu/kernels:tpu_compilation_cache_key",
@@ -83,6 +90,7 @@ cc_library(
         "//tensorflow/core/tpu/kernels:tpu_program_group",
         "//tensorflow/core/tpu/kernels:tpu_program_group_interface",
         "//tensorflow/stream_executor:stream_executor_headers",
+        "@com_google_absl//absl/strings",
     ],
     alwayslink = 1,
 )
diff --git a/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc b/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc
new file mode 100644
index 00000000000..f1539fd2d10
--- /dev/null
+++ b/tensorflow/compiler/xrt/kernels/tpu_execute_op.cc
@@ -0,0 +1,488 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/strings/str_cat.h"
+#include "tensorflow/compiler/jit/xla_device.h"
+#include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/compiler/xrt/xrt.pb.h"
+#include "tensorflow/compiler/xrt/xrt_memory_manager.h"
+#include "tensorflow/compiler/xrt/xrt_metrics.h"
+#include "tensorflow/compiler/xrt/xrt_state.h"
+#include "tensorflow/compiler/xrt/xrt_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/monitoring/timed.h"
+#include "tensorflow/core/platform/casts.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
+#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
+#include "tensorflow/core/tpu/tpu_configuration.h"
+#include "tensorflow/core/tpu/tpu_defs.h"
+#include "tensorflow/core/tpu/tpu_execute.h"
+#include "tensorflow/stream_executor/stream_executor.h"
+#include "tensorflow/stream_executor/stream_executor_internal.h"
+
+namespace tensorflow {
+namespace {
+
+using tensorflow::tpu::CompilationCacheEntryRef;
+using tensorflow::tpu::TpuCompilationCacheEntry;
+using tensorflow::tpu::TpuCompilationCacheLookup;
+using GetBufferFunction =
+    std::function<xla::StatusOr<std::vector<xla::ExecutionInput>>()>;
+
+// Looks up the input `key` in the compilation cache.
+Status GetComputationCacheEntry(
+    ResourceMgr* rm, int64 key, int core_index_in_replica,
+    std::unique_ptr<CompilationCacheEntryRef>* entry) {
+  profiler::TraceMe trace_me("XRTExecuteOp::LookupProto", /*level=*/2);
+  TpuCompilationCacheLookup* proto_lookup;
+  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
+                                tpu::kCompiledProtoCacheResourceName,
+                                &proto_lookup));
+  core::ScopedUnref lookup_unref(proto_lookup);
+  TF_RETURN_IF_ERROR(proto_lookup->Lookup(key, core_index_in_replica, entry));
+  return Status::OK();
+}
+
+// Returns, for each input of the executable, whether its compile-time shape
+// is dynamic (i.e. not fully static).
+std::vector<bool> GetDynamicInputInfo(
+    const TPUExecutableInfoProto& executable_proto) {
+  std::vector<bool> input_is_dynamic;
+  input_is_dynamic.reserve(executable_proto.input_shapes().size());
+  for (int64 i = 0; i < executable_proto.input_shapes().size(); ++i) {
+    input_is_dynamic.push_back(
+        !xla::Shape(executable_proto.input_shapes(i)).is_static());
+  }
+  return input_is_dynamic;
+}
+
+xla::StatusOr<std::vector<RefPtr<XRTTupleAllocation>>> GetChainedOpInputs(
+    const xrt::XRTChainedExecuteOp& op,
+    absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs,
+    const TPUExecutableInfoProto& executable_proto) {
+  if (op.inputs_size() != executable_proto.input_shapes_size()) {
+    return errors::InvalidArgument(
+        "Number of inputs does not match executable proto input shapes: ",
+        op.inputs_size(), " vs. ", executable_proto.input_shapes_size());
+  }
+
+  std::vector<RefPtr<XRTTupleAllocation>> input_tuples;
+  input_tuples.reserve(op.inputs_size());
+  for (int i = 0; i < op.inputs_size(); ++i) {
+    auto& input = op.inputs(i);
+    const RefPtr<XRTTupleAllocation>& tuple = op_inputs[i];
+    // Proto3 offers no way to query whether a field was explicitly set, so
+    // the default value of output_index (zero) means "no sub-index". As a
+    // consequence, the real index is output_index - 1.
+    if (input.output_index() == 0) {
+      input_tuples.push_back(tuple);
+    } else {
+      XRTTupleAllocation* sub_tuple;
+      TF_RETURN_IF_ERROR(XRTTupleAllocation::MakeSubBuffer(
+          tuple.get(), {input.output_index() - 1}, &sub_tuple,
+          /*alias_parent_allocation=*/true));
+      input_tuples.emplace_back(sub_tuple);
+    }
+    if (!InputShapeMatches(xla::Shape(executable_proto.input_shapes(i)),
+                           input_tuples.back()->on_host_shape())) {
+      return errors::InvalidArgument(
+          "Run-time shape mismatch for XRTExecute argument[", i, "] (",
+          op.computation_handle(), "). Expected ",
+          executable_proto.input_shapes(i).DebugString(), "; got ",
+          input_tuples.back()->on_host_shape().DebugString());
+    }
+  }
+  return std::move(input_tuples);
+}
+
+xla::StatusOr<xla::HloInputOutputAliasConfig> GetExecutableAliasConfig(
+    const tpu::TpuProgramGroup* tpu_program_group, xla::Backend* const backend,
+    int core_index) {
+  const TPUExecutableInfoProto& executable =
+      tpu_program_group->executable_info(core_index);
+  return xla::HloInputOutputAliasConfig::CreateFromProto(
+      backend->transfer_manager()->HostShapeToDeviceShape(
+          xla::Shape(executable.output_shape())),
+      tpu_program_group->hlo_metadata(core_index)
+          ->hlo_module()
+          .input_output_alias());
+}
+
+xla::StatusOr<RefPtr<XRTTupleAllocation>> AllocateOutputTuple(
+    tpu::TpuNodeContext* node_context, se::Stream* stream,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const xla::HloInputOutputAliasConfig& input_output_alias,
+    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
+  auto output_shaped_buffer = output_scoped_buffer.release();
+
+  xla::Shape output_device_shape = output_shaped_buffer.on_device_shape();
+  if (!output_device_shape.is_static()) {
+    TF_RETURN_IF_ERROR(
+        node_context->backend()->transfer_manager()->ReadDynamicShapes(
+            stream, &output_shaped_buffer, &output_device_shape));
+  }
+
+  XRTTupleAllocation* output_tuple;
+  xla::Shape output_host_shape =
+      xla::ShapeUtil::DeviceShapeToHostShape(output_device_shape);
+
+  TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
+      output_shaped_buffer, output_host_shape, output_device_shape,
+      node_context->backend(), device_ordinal, &output_tuple));
+  RefPtr<XRTTupleAllocation> output_tuple_ptr(output_tuple);
+
+  // If the input tuples had to release some buffers in order to provide the
+  // proper temporary ownership transfer, we patch the holes here by aliasing
+  // the buffers from the result tuple. The device addresses we patch back
+  // here are essentially the same ones we carved out in the DoWork() function.
+  TF_RETURN_IF_ERROR(
+      RebuildOutputAliases(output_tuple_ptr, input_tuples, input_output_alias));
+
+  return std::move(output_tuple_ptr);
+}
+
+Status AllocateOutputTensors(
+    OpKernelContext* context, XRTMemoryManager* memory_manager,
+    tpu::TpuNodeContext* node_context, se::Stream* stream,
+    const xrt::XRTExecutionConfig& config_proto,
+    const TPUExecutableInfoProto& executable_proto,
+    absl::Span<const RefPtr<XRTTupleAllocation>> input_tuples,
+    const xla::HloInputOutputAliasConfig& input_output_alias,
+    xla::ScopedShapedBuffer output_scoped_buffer, int device_ordinal) {
+  TF_ASSIGN_OR_RETURN(
+      RefPtr<XRTTupleAllocation> output_tuple,
+      AllocateOutputTuple(node_context, stream, input_tuples,
+                          input_output_alias, std::move(output_scoped_buffer),
+                          device_ordinal));
+  return CreateExecuteOutput(context, memory_manager, std::move(output_tuple),
+                             config_proto.return_exploded_tuple());
+}
+
+xla::StatusOr<xla::ExecutionOutput> RunExecutable(
+    OpKernelContext* context, tpu::TpuNodeContext* node_context,
+    const TPUExecutableInfoProto& executable,
+    std::vector<xla::ExecutionInput> arguments, const string& execution_id,
+    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
+    xla::Backend* const backend, se::Stream* stream, int core_index,
+    int device_ordinal, string rendezvous_key_base) {
+  profiler::TraceMe trace_me("RunExecutable", /*level=*/2);
+
+  // se::StreamExecutor* executor = node_context->stream_executor();
+
+  std::unique_ptr<xla::DeviceAssignment> device_assignment;
+  if (executable.has_device_assignment()) {
+    TF_ASSIGN_OR_RETURN(device_assignment,
+                        xla::DeviceAssignment::Deserialize(
+                            executable.device_assignment()));
+  }
+  // Ideally this should be the host-to-device stream from XlaDeviceContext.
+  // The particular anti-dependency this is avoiding (why we need a separate
+  // transfer stream) is between the executable writing tuple tables and
+  // TPUExecute()'s deregister_stream; if they come from the same stream pool,
+  // anti-dependencies will occur. XlaBackend has a different pool of streams
+  // from the stream->GetOrCreateSubStream() that TPUExecute() uses, so these
+  // will never refer to the same stream.
+  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
+                      backend->BorrowStream(device_ordinal));
+  const TPUHostTransferInfoProto& host_transfer_info =
+      tpu_program_group->host_transfer_info(core_index);
+  TF_ASSIGN_OR_RETURN(
+      xla::ExecutionOutput output,
+      TPUExecute(executable, host_transfer_info,
+                 *tpu_program_group->hlo_metadata(core_index),
+                 std::move(arguments), rendezvous_key_base, rng_seed,
+                 node_context, device_assignment.get(),
+                 context->cancellation_manager(), context, stream,
+                 transfer_stream_ptr.get(),
+                 tpu_program_group->tpu_program(core_index)));
+
+  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
+
+  return output;
+}
+
+xla::StatusOr<xla::ExecutionOutput> ExecuteTPUProgram(
+    OpKernelContext* context, tpu::TpuNodeContext* node_context,
+    XRTMemoryManager* memory_manager, const TPUExecutableInfoProto& executable,
+    const GetBufferFunction& get_buffers_fn, const string& execution_id,
+    const uint32 rng_seed, const tpu::TpuProgramGroup* tpu_program_group,
+    xla::Backend* const backend, se::Stream* stream, int core_index,
+    int device_ordinal, string rendezvous_key_base) {
+  auto runfn = [&]() -> xla::StatusOr<xla::ExecutionOutput> {
+    TF_ASSIGN_OR_RETURN(auto arguments, get_buffers_fn());
+    return RunExecutable(context, node_context, executable,
+                         std::move(arguments), execution_id, rng_seed,
+                         tpu_program_group, backend, stream, core_index,
+                         device_ordinal, rendezvous_key_base);
+  };
+  return memory_manager->Run<xla::ExecutionOutput>(
+      runfn, backend, device_ordinal, /*requested_free_size=*/0);
+}
+
+// XRTExecuteOp
+
+class XRTExecuteOp : public AsyncOpKernel {
+ public:
+  explicit XRTExecuteOp(OpKernelConstruction* context);
+
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+  Status DoWork(OpKernelContext* context);
+};
+
+XRTExecuteOp::XRTExecuteOp(OpKernelConstruction* context)
+    : AsyncOpKernel(context, /* is_deferred = */ true) {}
+
+void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
+  // Schedule onto the default queue, for unbounded concurrency. See b/73520706.
+  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
+  done();
+}
+
+Status XRTExecuteOp::DoWork(OpKernelContext* context) {
+  VLOG(1) << "XRTExecuteOp::Compute";
+
+  const XlaDevice::Metadata* metadata;
+  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
+  const int device_ordinal = metadata->device_ordinal();
+  // We are guaranteed that the object underlying TpuNodeContext won't be
+  // deleted out from under us while node_context is alive.
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
+                      tpu::TpuNodeContext::Create(device_ordinal));
+  xla::Backend* const backend = node_context->backend();
+  se::Stream* stream = context->op_device_context()->stream();
+
+  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteCell());
+  profiler::TraceMe trace_me(
+      [context] {
+        return profiler::TraceMeEncode("TpuExecuteOp",
+                                       {{"step_id", context->step_id()}});
+      },
+      /*level=*/2);
+  profiler::TraceMe trace_me_init("XRTExecuteOp::Init", /*level=*/2);
+
+  auto* rm = GetTPUConfigResourceMgr();
+  TF_RET_CHECK(rm != nullptr);
+
+  const Tensor& execution_input = context->input(0);
+  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
+  int64 compilation_handle = execution_input.scalar<int64>()();
+
+  const Tensor& execution_config = context->input(1);
+  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
+  xrt::XRTExecutionConfig config_proto;
+  TF_RET_CHECK(
+      config_proto.ParseFromString(execution_config.scalar<tstring>()()));
+
+  int core_index_in_replica = config_proto.core_index_in_replica();
+  bool release_inputs = config_proto.release_input_handles();
+  bool release_compilation = config_proto.release_compilation_handle();
+
+  string rendezvous_key_base = std::to_string(compilation_handle);
+  std::unique_ptr<CompilationCacheEntryRef> entry;
+  TF_RETURN_IF_ERROR(GetComputationCacheEntry(rm, compilation_handle,
+                                              core_index_in_replica, &entry));
+
+  TpuCompilationCacheEntry centry = entry->get();
+  const tpu::TpuProgramGroup* tpu_program_group =
+      tensorflow::down_cast<const tpu::TpuProgramGroup*>(
+          centry.tpu_program_group());
+  CHECK_NE(tpu_program_group, nullptr);
+
+  if (release_compilation) {
+    // Process-wide cache of TPU executables.
+    tpu::TpuCompilationCacheInterface* cache;
+    TF_RETURN_IF_ERROR(rm->Lookup(
+        rm->default_container(), tpu::kCompilationCacheResourceName, &cache));
+    core::ScopedUnref cache_unref(cache);
+    TF_RETURN_IF_ERROR(cache->Release(compilation_handle));
+    VLOG(2) << "Released compilation handle " << compilation_handle;
+  }
+
+  const int core_index = centry.core_index();
+  const TPUExecutableInfoProto& executable =
+      tpu_program_group->executable_info(core_index);
+
+  std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);
+
+  TF_ASSIGN_OR_RETURN(
+      xla::HloInputOutputAliasConfig input_output_alias,
+      GetExecutableAliasConfig(tpu_program_group, backend, core_index));
+  TF_ASSIGN_OR_RETURN(std::vector<InputCoords> input_coords,
+                      GetComputationInputs(context, "input_handles"));
+
+  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
+  XRTMemoryManager::WorkingSet working_set(memory_manager);
+  TF_ASSIGN_OR_RETURN(
+      std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
+      GetInputTupleAllocations(
+          input_coords, &working_set, backend, executable.input_shapes_size(),
+          [&](int64 i) { return xla::Shape(executable.input_shapes(i)); },
+          release_inputs));
+  auto get_buffers_fn = [&]() {
+    return GetArgumentsBuffers(input_output_alias, input_tuples,
+                               input_is_dynamic, release_inputs);
+  };
+  trace_me_init.Stop();
+
+  TF_ASSIGN_OR_RETURN(
+      xla::ExecutionOutput output,
+      ExecuteTPUProgram(
+          context, node_context.get(), memory_manager.get(), executable,
+          get_buffers_fn, config_proto.execution_instance_key(),
+          config_proto.rng_seed(), tpu_program_group, backend, stream,
+          core_index, device_ordinal, rendezvous_key_base));
+
+  // AllocateOutputTensors writes the output tuple handle to the op's output
+  // tensor.
+  TF_RETURN_IF_ERROR(AllocateOutputTensors(
+      context, memory_manager.get(), node_context.get(), stream, config_proto,
+      executable, input_tuples, input_output_alias, output.ConsumeResult(),
+      device_ordinal));
+  return Status::OK();
+}
+
+class XRTExecuteChainedOp : public AsyncOpKernel {
+ public:
+  explicit XRTExecuteChainedOp(OpKernelConstruction* context);
+
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override;
+
+ private:
+  Status DoWork(OpKernelContext* context);
+};
+
+XRTExecuteChainedOp::XRTExecuteChainedOp(OpKernelConstruction* context)
+    : AsyncOpKernel(context, /* is_deferred = */ true) {}
+
+void XRTExecuteChainedOp::ComputeAsync(OpKernelContext* context,
+                                       DoneCallback done) {
+  // Schedule onto the default queue, for unbounded concurrency. See b/73520706.
+  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
+  done();
+}
+
+Status XRTExecuteChainedOp::DoWork(OpKernelContext* context) {
+  VLOG(1) << "XRTExecuteChainedOp::Compute";
+  const XlaDevice::Metadata* metadata;
+  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
+  const int device_ordinal = metadata->device_ordinal();
+  // We are guaranteed that the object underlying TpuNodeContext won't be
+  // deleted out from under us while node_context is alive.
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_context,
+                      tpu::TpuNodeContext::Create(device_ordinal));
+  xla::Backend* const backend = node_context->backend();
+  se::Stream* stream = context->op_device_context()->stream();
+  auto timed = monitoring::MakeTimed(xrt_metrics::GetExecuteChainedCell());
+  profiler::TraceMe trace_me(
+      [context] {
+        return profiler::TraceMeEncode("TpuExecuteChainedOp",
+                                       {{"step_id", context->step_id()}});
+      },
+      /*level=*/2);
+  ResourceMgr* rm = GetTPUConfigResourceMgr();
+  TF_RET_CHECK(rm != nullptr);
+
+  const Tensor& execution_plan = context->input(0);
+  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_plan.shape()));
+  xrt::XRTChainedExecutePlan plan;
+  TF_RET_CHECK(plan.ParseFromString(execution_plan.scalar<tstring>()()));
+
+  const Tensor& execution_config = context->input(1);
+  TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
+  xrt::XRTChainedExecuteConfig config;
+  TF_RET_CHECK(config.ParseFromString(execution_config.scalar<tstring>()()));
+
+  TpuCompilationCacheLookup* proto_lookup;
+  TF_RETURN_IF_ERROR(rm->Lookup(rm->default_container(),
+                                tpu::kCompiledProtoCacheResourceName,
+                                &proto_lookup));
+  core::ScopedUnref lookup_unref(proto_lookup);
+  RefPtr<XRTMemoryManager> memory_manager = XRTMemoryManager::Get(rm);
+  auto execute_op = [&](const xrt::XRTChainedExecuteOp& op,
+                        absl::Span<const RefPtr<XRTTupleAllocation>> op_inputs)
+      -> xla::StatusOr<RefPtr<XRTTupleAllocation>> {
+    std::unique_ptr<CompilationCacheEntryRef> entry;
+    TF_RETURN_IF_ERROR(proto_lookup->Lookup(
+        op.computation_handle(), config.core_index_in_replica(), &entry));
+    string rendezvous_key_base = std::to_string(op.computation_handle());
+    TpuCompilationCacheEntry centry = entry->get();
+    const tpu::TpuProgramGroup* tpu_program_group =
+        tensorflow::down_cast<const tpu::TpuProgramGroup*>(
+            centry.tpu_program_group());
+    CHECK_NE(tpu_program_group, nullptr);
+    const int core_index = centry.core_index();
+    const TPUExecutableInfoProto& executable =
+        tpu_program_group->executable_info(core_index);
+    std::vector<bool> input_is_dynamic = GetDynamicInputInfo(executable);
+
+    TF_ASSIGN_OR_RETURN(
+        xla::HloInputOutputAliasConfig input_output_alias,
+        GetExecutableAliasConfig(tpu_program_group, backend, core_index));
+    TF_ASSIGN_OR_RETURN(std::vector<RefPtr<XRTTupleAllocation>> input_tuples,
+                        GetChainedOpInputs(op, op_inputs, executable));
+    auto get_buffers_fn = [&]() {
+      return GetArgumentsBuffers(input_output_alias, input_tuples,
+                                 input_is_dynamic,
+                                 /*release_inputs=*/false);
+    };
+    TF_ASSIGN_OR_RETURN(
+        xla::ExecutionOutput output,
+        ExecuteTPUProgram(context, node_context.get(), memory_manager.get(),
+                          executable, get_buffers_fn,
+                          config.execution_instance_key(), config.rng_seed(),
+                          tpu_program_group, backend, stream, core_index,
+                          device_ordinal, rendezvous_key_base));
+    return AllocateOutputTuple(node_context.get(), stream, input_tuples,
+                               input_output_alias, output.ConsumeResult(),
+                               device_ordinal);
+  };
+
+  return ExecuteChained(context, memory_manager, backend, device_ordinal, plan,
+                        config, execute_op);
+}
+
+}  // namespace
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecute")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("computation_handle")
+                            .HostMemory("execution_config")
+                            .HostMemory("input_handles")
+                            .HostMemory("output_handle"),
+                        XRTExecuteOp);
+
+REGISTER_KERNEL_BUILDER(Name("XRTExecuteChained")
+                            .Device(DEVICE_TPU_NODE)
+                            .HostMemory("execution_plan")
+                            .HostMemory("execution_config")
+                            .HostMemory("output_handle"),
+                        XRTExecuteChainedOp);
+
+}  // namespace tensorflow