Add TpuExecutableInterface::fingerprint() virtual method.
This also makes the TpuExecutable in tpu_on_demand_compiler.cc subclass TpuExecutableInterface, and implements the fingerprint() method for future use by JAX. I didn't implement it for the TpuExecutable class in tpu_executable.h, since TF doesn't need this functionality (yet?), but it shouldn't be too hard. PiperOrigin-RevId: 330842613 Change-Id: I592068c7b1110e0ae32b241e3e6c5a7b121f3e0f
This commit is contained in:
		
							parent
							
								
									eb461280fe
								
							
						
					
					
						commit
						d59bdf5493
					
				| @ -298,6 +298,7 @@ cc_library( | ||||
|         "//tensorflow/stream_executor/tpu:c_api_decl", | ||||
|         "//tensorflow/stream_executor/tpu:proto_helper", | ||||
|         "//tensorflow/stream_executor/tpu:status_helper", | ||||
|         "//tensorflow/stream_executor/tpu:tpu_executable_interface", | ||||
|         "//tensorflow/stream_executor/tpu:tpu_executor", | ||||
|         "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs", | ||||
|         "@com_google_absl//absl/types:span", | ||||
|  | ||||
| @ -137,6 +137,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) { | ||||
|   TFTPU_SET_FN(executor_fn, TpuCompiler_Compile); | ||||
|   TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize); | ||||
|   TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream); | ||||
|   TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint); | ||||
|   TFTPU_SET_FN(executor_fn, TpuExecutable_Free); | ||||
| 
 | ||||
|   TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation); | ||||
|  | ||||
| @ -29,6 +29,7 @@ limitations under the License. | ||||
| #include "tensorflow/stream_executor/tpu/c_api_decl.h" | ||||
| #include "tensorflow/stream_executor/tpu/proto_helper.h" | ||||
| #include "tensorflow/stream_executor/tpu/status_helper.h" | ||||
| #include "tensorflow/stream_executor/tpu/tpu_executable_interface.h" | ||||
| #include "tensorflow/stream_executor/tpu/tpu_executor.h" | ||||
| #include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h" | ||||
| #include "tensorflow/stream_executor/tpu/tpu_platform.h" | ||||
| @ -97,11 +98,11 @@ void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) { | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| class TpuExecutable : public Executable { | ||||
| class TpuExecutable : public TpuExecutableInterface { | ||||
|  public: | ||||
|   TpuExecutable(SE_Executable* se_executable, | ||||
|                 std::shared_ptr<HloModule> hlo_module) | ||||
|       : Executable(std::move(hlo_module), nullptr, nullptr), | ||||
|       : TpuExecutableInterface(std::move(hlo_module), nullptr, nullptr), | ||||
|         se_executable_(se_executable) {} | ||||
| 
 | ||||
|   ~TpuExecutable() override { | ||||
| @ -192,7 +193,31 @@ class TpuExecutable : public Executable { | ||||
|     return output; | ||||
|   } | ||||
| 
 | ||||
|   absl::string_view fingerprint() const override { | ||||
|     const char* data; | ||||
|     size_t size; | ||||
|     ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size); | ||||
|     return absl::string_view(data, size); | ||||
|   } | ||||
| 
 | ||||
|  private: | ||||
|   Status LoadProgramAndEnqueueToStream( | ||||
|       const ServiceExecutableRunOptions& run_options, | ||||
|       absl::Span<const stream_executor::DeviceMemoryBase> arguments, | ||||
|       stream_executor::DeviceMemoryBase result, | ||||
|       absl::optional<stream_executor::DeviceMemoryBase> | ||||
|           cross_program_prefetch_addr) override { | ||||
|     LOG(FATAL) << "LoadProgramAndEnqueueToStream unimplemented"; | ||||
|   } | ||||
| 
 | ||||
|   Shape HostShapeToDeviceShape(const Shape& host_shape) override { | ||||
|     LOG(FATAL) << "HostShapeToDeviceShape unimplemented"; | ||||
|   } | ||||
| 
 | ||||
|   int64 ShapeSize(const Shape& shape) override { | ||||
|     LOG(FATAL) << "ShapeSize unimplemented"; | ||||
|   } | ||||
| 
 | ||||
|   SE_Executable* se_executable_; | ||||
| }; | ||||
| 
 | ||||
|  | ||||
| @ -113,4 +113,9 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) { | ||||
|   return size; | ||||
| } | ||||
| 
 | ||||
| absl::string_view TpuExecutable::fingerprint() const { | ||||
|   // TODO(skye): the fingerprint can be plumbed through via core_program_
 | ||||
|   LOG(FATAL) << "TpuExecutable::fingerprint() unimplemented"; | ||||
| } | ||||
| 
 | ||||
| }  // namespace xla
 | ||||
|  | ||||
| @ -46,6 +46,8 @@ class TpuExecutable : public TpuExecutableInterface { | ||||
| 
 | ||||
|   const XLA_TpuProgram* core_program() const { return core_program_; } | ||||
| 
 | ||||
|   absl::string_view fingerprint() const override; | ||||
| 
 | ||||
|  private: | ||||
|   Status LoadProgramAndEnqueueToStream( | ||||
|       const ServiceExecutableRunOptions& run_options, | ||||
|  | ||||
| @ -80,6 +80,8 @@ class TpuExecutableInterface : public Executable { | ||||
|       absl::optional<stream_executor::DeviceMemoryBase> | ||||
|           cross_program_prefetch_addr) = 0; | ||||
| 
 | ||||
|   virtual absl::string_view fingerprint() const = 0; | ||||
| 
 | ||||
|  protected: | ||||
|   virtual Shape HostShapeToDeviceShape(const Shape& host_shape) = 0; | ||||
| 
 | ||||
|  | ||||
| @ -300,6 +300,10 @@ TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream( | ||||
|     SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output, | ||||
|     SE_Status* status); | ||||
| 
 | ||||
| TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable, | ||||
|                                                  const char** fingerprint, | ||||
|                                                  size_t* size); | ||||
| 
 | ||||
| TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*); | ||||
| 
 | ||||
| // Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
 | ||||
| @ -445,6 +449,7 @@ struct TfTpu_ExecutorApiFn { | ||||
|   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile); | ||||
|   TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize); | ||||
|   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream); | ||||
|   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint); | ||||
|   TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free); | ||||
| 
 | ||||
|   TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation); | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user