Open source the XRTTpuDeviceAccessor and register XRTStateOps on TPU on the open source XRT side.
PiperOrigin-RevId: 345582775
Change-Id: Ie97bae4e37b16ba8bce411c016e80b50f7a26ff1
This commit is contained in:
		
							parent
							
								
									1943e58d29
								
							
						
					
					
						commit
						293bd7502b
					
				| @ -41,7 +41,6 @@ cc_library( | ||||
|         "xrt_memory_manager.cc", | ||||
|         "xrt_metrics.cc", | ||||
|         "xrt_state.cc", | ||||
|         "xrt_tpu_device.cc", | ||||
|         "xrt_util.cc", | ||||
|     ], | ||||
|     hdrs = [ | ||||
| @ -51,7 +50,6 @@ cc_library( | ||||
|         "xrt_metrics.h", | ||||
|         "xrt_refptr.h", | ||||
|         "xrt_state.h", | ||||
|         "xrt_tpu_device.h", | ||||
|         "xrt_util.h", | ||||
|     ], | ||||
|     visibility = ["//visibility:public"], | ||||
| @ -78,10 +76,8 @@ cc_library( | ||||
|         "//tensorflow/core:lib_internal", | ||||
|         "//tensorflow/core/platform:regexp", | ||||
|         "//tensorflow/core/profiler/lib:traceme", | ||||
|         "//tensorflow/core/tpu:tpu_configuration", | ||||
|         "//tensorflow/stream_executor", | ||||
|         "//tensorflow/stream_executor:device_memory_allocator", | ||||
|         "//tensorflow/stream_executor/tpu:tpu_node_context", | ||||
|         "@com_google_absl//absl/memory", | ||||
|         "@com_google_absl//absl/synchronization", | ||||
|     ], | ||||
|  | ||||
| @ -73,7 +73,6 @@ cc_library( | ||||
|         "//tensorflow/core:lib", | ||||
|         "//tensorflow/core:lib_internal", | ||||
|         "//tensorflow/core:protos_all_cc", | ||||
|         "//tensorflow/core/tpu:tpu_defs", | ||||
|         "//tensorflow/stream_executor:stream_executor_headers", | ||||
|         "@com_google_absl//absl/strings", | ||||
|     ], | ||||
|  | ||||
| @ -24,7 +24,6 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/tf2xla/xla_op_registry.h" | ||||
| #include "tensorflow/compiler/xla/client/local_client.h" | ||||
| #include "tensorflow/compiler/xrt/xrt_metrics.h" | ||||
| #include "tensorflow/core/tpu/tpu_defs.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| namespace { | ||||
| @ -68,11 +67,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocate") | ||||
|                             .HostMemory("allocation") | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTAllocate") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("allocation") | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -82,10 +76,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") | ||||
|                             .Device(DEVICE_XLA_CPU) | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateUninitializedOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTAllocateUninitialized") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateUninitializedOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -97,11 +87,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") | ||||
|                             .HostMemory("inputs") | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateFromTensorOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTAllocateFromTensor") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("inputs") | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTAllocateFromTensorOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -115,12 +100,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") | ||||
|                             .HostMemory("shape_index") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTSubTupleOp<false, XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTSubTuple") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("base_handle") | ||||
|                             .HostMemory("shape_index") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTSubTupleOp<false, XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -134,12 +113,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") | ||||
|                             .HostMemory("shape_index") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTSubTupleOp<true, XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTSubTupleAndRelease") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("base_handle") | ||||
|                             .HostMemory("shape_index") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTSubTupleOp<true, XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -153,12 +126,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") | ||||
|                             .HostMemory("input_handles") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTMakeTupleOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTMakeTuple") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("tuple_description") | ||||
|                             .HostMemory("input_handles") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTMakeTupleOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -170,11 +137,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") | ||||
|                             .HostMemory("handle") | ||||
|                             .HostMemory("literal"), | ||||
|                         XRTReadLiteralOp<false, XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadLiteral") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handle") | ||||
|                             .HostMemory("literal"), | ||||
|                         XRTReadLiteralOp<false, XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -188,12 +150,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") | ||||
|                             .HostMemory("literal") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTWriteLiteralOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTWriteLiteral") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handle") | ||||
|                             .HostMemory("literal") | ||||
|                             .HostMemory("output_handle"), | ||||
|                         XRTWriteLiteralOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -205,11 +161,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") | ||||
|                             .HostMemory("handle") | ||||
|                             .HostMemory("literal"), | ||||
|                         XRTReadLiteralOp<true, XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadLiteralAndRelease") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handle") | ||||
|                             .HostMemory("literal"), | ||||
|                         XRTReadLiteralOp<true, XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -221,11 +172,6 @@ REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") | ||||
|                             .HostMemory("handles") | ||||
|                             .HostMemory("tensors"), | ||||
|                         XRTReadToTensorOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReadToTensor") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handles") | ||||
|                             .HostMemory("tensors"), | ||||
|                         XRTReadToTensorOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") | ||||
|                             .Device(DEVICE_XLA_GPU) | ||||
| @ -235,25 +181,16 @@ REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") | ||||
|                             .Device(DEVICE_XLA_CPU) | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTReleaseAllocationOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllocationHandle") | ||||
|                             .Device(DEVICE_TPU_NODE) | ||||
|                             .HostMemory("handle"), | ||||
|                         XRTReleaseAllocationOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_GPU), | ||||
|                         XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTReleaseAllAllocations").Device(DEVICE_XLA_CPU), | ||||
|                         XRTReleaseAllAllocationsOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER( | ||||
|     Name("XRTReleaseAllAllocations").Device(DEVICE_TPU_NODE), | ||||
|     XRTReleaseAllAllocationsOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_GPU), | ||||
|                         XRTCompactAllocationsOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_XLA_CPU), | ||||
|                         XRTCompactAllocationsOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTCompactAllocations").Device(DEVICE_TPU_NODE), | ||||
|                         XRTCompactAllocationsOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTMetricsCollect").Device(DEVICE_CPU), | ||||
|                         XRTMetricsCollectOp); | ||||
| @ -262,7 +199,5 @@ REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_GPU), | ||||
|                         XRTMemoryInfoOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_XLA_CPU), | ||||
|                         XRTMemoryInfoOp<XRTGenericDeviceAccessor>); | ||||
| REGISTER_KERNEL_BUILDER(Name("XRTMemoryInfo").Device(DEVICE_TPU_NODE), | ||||
|                         XRTMemoryInfoOp<XRTTpuDeviceAccessor>); | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
|  | ||||
| @ -37,7 +37,6 @@ limitations under the License. | ||||
| #include "tensorflow/compiler/xrt/xrt_memory_manager.h" | ||||
| #include "tensorflow/compiler/xrt/xrt_metrics.h" | ||||
| #include "tensorflow/compiler/xrt/xrt_state.h" | ||||
| #include "tensorflow/compiler/xrt/xrt_tpu_device.h" | ||||
| #include "tensorflow/core/common_runtime/dma_helper.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/resource_mgr.h" | ||||
|  | ||||
| @ -1,61 +0,0 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| #include "tensorflow/compiler/xrt/xrt_tpu_device.h" | ||||
| 
 | ||||
| #include "tensorflow/compiler/jit/xla_device.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/resource_mgr.h" | ||||
| #include "tensorflow/core/lib/core/status.h" | ||||
| #include "tensorflow/core/tpu/tpu_configuration.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
// Stores the result of GetTPUConfigResourceMgr() in `*rm`.
//
// Returns an Internal error ("No Tpu resource manager.") when that call
// yields nullptr, and OK otherwise. `*rm` is assigned on both paths, so on
// failure the caller's pointer has been overwritten with nullptr.
| /*static*/ Status XRTTpuDeviceAccessor::GetResourceManager(OpKernelContext* ctx, | ||||
|                                                            ResourceMgr** rm) { | ||||
|   // ctx is unused here, but maintained because XRTGenericDeviceAccessor uses
 | ||||
|   // it in its GetResourceManager.
 | ||||
|   *rm = GetTPUConfigResourceMgr(); | ||||
|   if (*rm == nullptr) { | ||||
|     return errors::Internal("No Tpu resource manager."); | ||||
|   } | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
// Creates a tpu::TpuNodeContext for `device_ordinal`, stores it in
// node_context_, and records the ordinal in ordinal_.
//
// If TpuNodeContext::Create fails, TF_ASSIGN_OR_RETURN is expected to return
// that error immediately, in which case ordinal_ is not updated.
| Status XRTTpuDeviceAccessor::ScopedRef::Acquire(int device_ordinal) { | ||||
|   TF_ASSIGN_OR_RETURN(node_context_, | ||||
|                       tpu::TpuNodeContext::Create(device_ordinal)); | ||||
|   ordinal_ = device_ordinal; | ||||
|   return Status::OK(); | ||||
| } | ||||
| 
 | ||||
// Overload that takes the device ordinal from the XlaDevice metadata attached
// to `ctx` and forwards to Acquire(int).
//
// Any error from XlaDevice::GetMetadata is returned before `metadata` is read.
| Status XRTTpuDeviceAccessor::ScopedRef::Acquire(OpKernelContext* ctx) { | ||||
    // Left uninitialized on purpose: GetMetadata fills it in through the
    // out-parameter below.
|   const XlaDevice::Metadata* metadata; | ||||
|   TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(ctx, &metadata)); | ||||
|   return Acquire(metadata->device_ordinal()); | ||||
| } | ||||
| 
 | ||||
// Initializes `scoped_ref` for an explicitly supplied `device_ordinal` by
// calling ScopedRef::Acquire(int). The OpKernelContext argument is ignored.
// Returns whatever status Acquire returns.
| /*static*/ Status XRTTpuDeviceAccessor::InitScopedRef( | ||||
|     OpKernelContext* /*unused ctx*/, int device_ordinal, | ||||
|     ScopedRef* scoped_ref) { | ||||
|   return scoped_ref->Acquire(device_ordinal); | ||||
| } | ||||
| 
 | ||||
// Initializes `scoped_ref` from `ctx` alone by calling
// ScopedRef::Acquire(OpKernelContext*), which derives the device ordinal from
// the context's XlaDevice metadata. Returns whatever status Acquire returns.
| /*static*/ Status XRTTpuDeviceAccessor::InitScopedRef(OpKernelContext* ctx, | ||||
|                                                       ScopedRef* scoped_ref) { | ||||
|   return scoped_ref->Acquire(ctx); | ||||
| } | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| @ -1,68 +0,0 @@ | ||||
| /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | ||||
| 
 | ||||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| you may not use this file except in compliance with the License. | ||||
| You may obtain a copy of the License at | ||||
| 
 | ||||
|     http://www.apache.org/licenses/LICENSE-2.0
 | ||||
| 
 | ||||
| Unless required by applicable law or agreed to in writing, software | ||||
| distributed under the License is distributed on an "AS IS" BASIS, | ||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| See the License for the specific language governing permissions and | ||||
| limitations under the License. | ||||
| ==============================================================================*/ | ||||
| 
 | ||||
| // Classes for keeping track of on-device state for TPUs.
 | ||||
| 
 | ||||
| #ifndef TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_ | ||||
| #define TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_ | ||||
| 
 | ||||
| #include "tensorflow/compiler/xla/client/local_client.h" | ||||
| #include "tensorflow/core/framework/op_kernel.h" | ||||
| #include "tensorflow/core/framework/resource_mgr.h" | ||||
| #include "tensorflow/stream_executor/tpu/tpu_node_context.h" | ||||
| 
 | ||||
| namespace tensorflow { | ||||
| 
 | ||||
| // This accessor is used for XLA TPU. It uses the distributed TPU compilation
 | ||||
| // cache infrastructure which it accesses via the TPU_SYSTEM resource manager.
 | ||||
| class XRTTpuDeviceAccessor { | ||||
|  public: | ||||
    // Writes the TPU resource manager to `*rm`. `ctx` is accepted only so the
    // signature matches the other device accessors.
|   static Status GetResourceManager(OpKernelContext* ctx, ResourceMgr** rm); | ||||
| 
 | ||||
    // Holds the tpu::TpuNodeContext (and the device ordinal it was created
    // for) for as long as the ScopedRef object is alive. Non-copyable.
|   class ScopedRef { | ||||
|    public: | ||||
      // A default-constructed ScopedRef holds no node context; it must be
      // passed to XRTTpuDeviceAccessor::InitScopedRef before use.
|     ScopedRef() {} | ||||
|     ~ScopedRef() {} | ||||
| 
 | ||||
|     ScopedRef(const ScopedRef&) = delete; | ||||
|     ScopedRef& operator=(const ScopedRef&) = delete; | ||||
| 
 | ||||
|     // Returns the XLA device properties from the TpuNodeContext object
 | ||||
|     // protected by this ScopedRef.
 | ||||
      // NOTE: node_context_ is only assigned in Acquire(), so calling
      // backend() before a successful InitScopedRef dereferences a null
      // unique_ptr.
|     xla::Backend* backend() { return node_context_->backend(); } | ||||
      // Returns the ordinal recorded by the last successful Acquire(), or 0
      // (the member's initial value) if none has succeeded yet.
|     int device_ordinal() { return ordinal_; } | ||||
| 
 | ||||
|    private: | ||||
|     // XRTTpuDeviceAccessor::InitScopedRef is the only way to initialize
 | ||||
|     // ScopedRef.
 | ||||
|     friend class XRTTpuDeviceAccessor; | ||||
| 
 | ||||
      // Creates the node context for the given ordinal and records it.
|     Status Acquire(int device_ordinal); | ||||
| 
 | ||||
      // Same as above, with the ordinal taken from the XlaDevice metadata of
      // `ctx`.
|     Status Acquire(OpKernelContext* ctx); | ||||
| 
 | ||||
      // NOTE(review): std::unique_ptr is used here but <memory> is not among
      // the includes visible in this header; presumably it arrives
      // transitively -- confirm.
|     std::unique_ptr<tpu::TpuNodeContext> node_context_; | ||||
|     int ordinal_ = 0; | ||||
|   }; | ||||
| 
 | ||||
    // Initializes `scoped_ref` for the given `device_ordinal`; `ctx` is not
    // used by this overload.
|   static Status InitScopedRef(OpKernelContext* ctx, int device_ordinal, | ||||
|                               ScopedRef* scoped_ref); | ||||
| 
 | ||||
    // Initializes `scoped_ref` for the device that `ctx` is running on.
|   static Status InitScopedRef(OpKernelContext* ctx, ScopedRef* scoped_ref); | ||||
| }; | ||||
| 
 | ||||
| }  // namespace tensorflow
 | ||||
| 
 | ||||
| #endif  // TENSORFLOW_COMPILER_XRT_XRT_TPU_DEVICE_H_
 | ||||
| @ -9,7 +9,6 @@ load( | ||||
| package( | ||||
|     default_visibility = [ | ||||
|         "//tensorflow/compiler/tf2xla/kernels:__subpackages__", | ||||
|         "//tensorflow/compiler/xrt:__subpackages__", | ||||
|         "//tensorflow/core/tpu:__subpackages__", | ||||
|         "//tensorflow/stream_executor/tpu:__subpackages__", | ||||
|     ], | ||||
|  | ||||
| @ -5,7 +5,6 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") | ||||
| package( | ||||
|     default_visibility = [ | ||||
|         "//learning/brain/experimental/dtensor:__subpackages__", | ||||
|         "//tensorflow/compiler/xrt:__subpackages__", | ||||
|         "//tensorflow/core/tpu:__subpackages__", | ||||
|     ], | ||||
|     licenses = ["notice"],  # Apache 2.0 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user