Add TpuExecutableInterface::fingerprint() virtual method.
This also makes the TpuExecutable in tpu_on_demand_compiler.cc subclass TpuExecutableInterface, and implements the fingerprint() method for future use by JAX. I didn't implement it for the TpuExecutable class in tpu_executable.h, since TF doesn't need this functionality (yet?), but it shouldn't be too hard. PiperOrigin-RevId: 330842613 Change-Id: I592068c7b1110e0ae32b241e3e6c5a7b121f3e0f
This commit is contained in:
parent
eb461280fe
commit
d59bdf5493
@ -298,6 +298,7 @@ cc_library(
|
||||
"//tensorflow/stream_executor/tpu:c_api_decl",
|
||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||
"//tensorflow/stream_executor/tpu:status_helper",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executor",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -137,6 +137,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);
|
||||
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
|
||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
||||
|
||||
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
@ -97,11 +98,11 @@ void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
|
||||
}
|
||||
}
|
||||
|
||||
class TpuExecutable : public Executable {
|
||||
class TpuExecutable : public TpuExecutableInterface {
|
||||
public:
|
||||
TpuExecutable(SE_Executable* se_executable,
|
||||
std::shared_ptr<HloModule> hlo_module)
|
||||
: Executable(std::move(hlo_module), nullptr, nullptr),
|
||||
: TpuExecutableInterface(std::move(hlo_module), nullptr, nullptr),
|
||||
se_executable_(se_executable) {}
|
||||
|
||||
~TpuExecutable() override {
|
||||
@ -192,7 +193,31 @@ class TpuExecutable : public Executable {
|
||||
return output;
|
||||
}
|
||||
|
||||
absl::string_view fingerprint() const override {
|
||||
const char* data;
|
||||
size_t size;
|
||||
ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
|
||||
return absl::string_view(data, size);
|
||||
}
|
||||
|
||||
private:
|
||||
Status LoadProgramAndEnqueueToStream(
|
||||
const ServiceExecutableRunOptions& run_options,
|
||||
absl::Span<const stream_executor::DeviceMemoryBase> arguments,
|
||||
stream_executor::DeviceMemoryBase result,
|
||||
absl::optional<stream_executor::DeviceMemoryBase>
|
||||
cross_program_prefetch_addr) override {
|
||||
LOG(FATAL) << "LoadProgramAndEnqueueToStream unimplemented";
|
||||
}
|
||||
|
||||
Shape HostShapeToDeviceShape(const Shape& host_shape) override {
|
||||
LOG(FATAL) << "HostShapeToDeviceShape unimplemented";
|
||||
}
|
||||
|
||||
int64 ShapeSize(const Shape& shape) override {
|
||||
LOG(FATAL) << "ShapeSize unimplemented";
|
||||
}
|
||||
|
||||
SE_Executable* se_executable_;
|
||||
};
|
||||
|
||||
|
@ -113,4 +113,9 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
|
||||
return size;
|
||||
}
|
||||
|
||||
absl::string_view TpuExecutable::fingerprint() const {
|
||||
// TODO(skye): the fingerprint can be plumbed through via core_program_
|
||||
LOG(FATAL) << "TpuExecutable::fingerprint() unimplemented";
|
||||
}
|
||||
|
||||
} // namespace xla
|
||||
|
@ -46,6 +46,8 @@ class TpuExecutable : public TpuExecutableInterface {
|
||||
|
||||
const XLA_TpuProgram* core_program() const { return core_program_; }
|
||||
|
||||
absl::string_view fingerprint() const override;
|
||||
|
||||
private:
|
||||
Status LoadProgramAndEnqueueToStream(
|
||||
const ServiceExecutableRunOptions& run_options,
|
||||
|
@ -80,6 +80,8 @@ class TpuExecutableInterface : public Executable {
|
||||
absl::optional<stream_executor::DeviceMemoryBase>
|
||||
cross_program_prefetch_addr) = 0;
|
||||
|
||||
virtual absl::string_view fingerprint() const = 0;
|
||||
|
||||
protected:
|
||||
virtual Shape HostShapeToDeviceShape(const Shape& host_shape) = 0;
|
||||
|
||||
|
@ -300,6 +300,10 @@ TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream(
|
||||
SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output,
|
||||
SE_Status* status);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
|
||||
const char** fingerprint,
|
||||
size_t* size);
|
||||
|
||||
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
||||
|
||||
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
||||
@ -445,6 +449,7 @@ struct TfTpu_ExecutorApiFn {
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
|
||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
||||
|
||||
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
||||
|
Loading…
Reference in New Issue
Block a user