Add TpuExecutableInterface::fingerprint() virtual method.
This also makes the TpuExecutable in tpu_on_demand_compiler.cc subclass TpuExecutableInterface, and implements the fingerprint() method for future use by JAX. I didn't implement it for the TpuExecutable class in tpu_executable.h, since TF doesn't need this functionality (yet?), but it shouldn't be too hard. PiperOrigin-RevId: 330842613 Change-Id: I592068c7b1110e0ae32b241e3e6c5a7b121f3e0f
This commit is contained in:
parent
eb461280fe
commit
d59bdf5493
tensorflow
@ -298,6 +298,7 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/tpu:c_api_decl",
|
"//tensorflow/stream_executor/tpu:c_api_decl",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
"//tensorflow/stream_executor/tpu:status_helper",
|
"//tensorflow/stream_executor/tpu:status_helper",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor",
|
"//tensorflow/stream_executor/tpu:tpu_executor",
|
||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
@ -137,6 +137,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
|
|||||||
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);
|
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);
|
||||||
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
|
||||||
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
|
||||||
|
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
|
||||||
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
|
||||||
|
|
||||||
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);
|
||||||
|
@ -29,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
#include "tensorflow/stream_executor/tpu/status_helper.h"
|
||||||
|
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||||
@ -97,11 +98,11 @@ void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
class TpuExecutable : public Executable {
|
class TpuExecutable : public TpuExecutableInterface {
|
||||||
public:
|
public:
|
||||||
TpuExecutable(SE_Executable* se_executable,
|
TpuExecutable(SE_Executable* se_executable,
|
||||||
std::shared_ptr<HloModule> hlo_module)
|
std::shared_ptr<HloModule> hlo_module)
|
||||||
: Executable(std::move(hlo_module), nullptr, nullptr),
|
: TpuExecutableInterface(std::move(hlo_module), nullptr, nullptr),
|
||||||
se_executable_(se_executable) {}
|
se_executable_(se_executable) {}
|
||||||
|
|
||||||
~TpuExecutable() override {
|
~TpuExecutable() override {
|
||||||
@ -192,7 +193,31 @@ class TpuExecutable : public Executable {
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::string_view fingerprint() const override {
|
||||||
|
const char* data;
|
||||||
|
size_t size;
|
||||||
|
ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
|
||||||
|
return absl::string_view(data, size);
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
Status LoadProgramAndEnqueueToStream(
|
||||||
|
const ServiceExecutableRunOptions& run_options,
|
||||||
|
absl::Span<const stream_executor::DeviceMemoryBase> arguments,
|
||||||
|
stream_executor::DeviceMemoryBase result,
|
||||||
|
absl::optional<stream_executor::DeviceMemoryBase>
|
||||||
|
cross_program_prefetch_addr) override {
|
||||||
|
LOG(FATAL) << "LoadProgramAndEnqueueToStream unimplemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape HostShapeToDeviceShape(const Shape& host_shape) override {
|
||||||
|
LOG(FATAL) << "HostShapeToDeviceShape unimplemented";
|
||||||
|
}
|
||||||
|
|
||||||
|
int64 ShapeSize(const Shape& shape) override {
|
||||||
|
LOG(FATAL) << "ShapeSize unimplemented";
|
||||||
|
}
|
||||||
|
|
||||||
SE_Executable* se_executable_;
|
SE_Executable* se_executable_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -113,4 +113,9 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
|
|||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::string_view TpuExecutable::fingerprint() const {
|
||||||
|
// TODO(skye): the fingerprint can be plumbed through via core_program_
|
||||||
|
LOG(FATAL) << "TpuExecutable::fingerprint() unimplemented";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -46,6 +46,8 @@ class TpuExecutable : public TpuExecutableInterface {
|
|||||||
|
|
||||||
const XLA_TpuProgram* core_program() const { return core_program_; }
|
const XLA_TpuProgram* core_program() const { return core_program_; }
|
||||||
|
|
||||||
|
absl::string_view fingerprint() const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Status LoadProgramAndEnqueueToStream(
|
Status LoadProgramAndEnqueueToStream(
|
||||||
const ServiceExecutableRunOptions& run_options,
|
const ServiceExecutableRunOptions& run_options,
|
||||||
|
@ -80,6 +80,8 @@ class TpuExecutableInterface : public Executable {
|
|||||||
absl::optional<stream_executor::DeviceMemoryBase>
|
absl::optional<stream_executor::DeviceMemoryBase>
|
||||||
cross_program_prefetch_addr) = 0;
|
cross_program_prefetch_addr) = 0;
|
||||||
|
|
||||||
|
virtual absl::string_view fingerprint() const = 0;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
virtual Shape HostShapeToDeviceShape(const Shape& host_shape) = 0;
|
virtual Shape HostShapeToDeviceShape(const Shape& host_shape) = 0;
|
||||||
|
|
||||||
|
@ -300,6 +300,10 @@ TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream(
|
|||||||
SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output,
|
SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output,
|
||||||
SE_Status* status);
|
SE_Status* status);
|
||||||
|
|
||||||
|
TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
|
||||||
|
const char** fingerprint,
|
||||||
|
size_t* size);
|
||||||
|
|
||||||
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
|
||||||
|
|
||||||
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
|
||||||
@ -445,6 +449,7 @@ struct TfTpu_ExecutorApiFn {
|
|||||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
|
||||||
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
|
||||||
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
|
||||||
|
|
||||||
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);
|
||||||
|
Loading…
Reference in New Issue
Block a user