Add TpuExecutableInterface::fingerprint() virtual method.

This also makes the TpuExecutable in tpu_on_demand_compiler.cc
subclass TpuExecutableInterface, and implements the fingerprint()
method for future use by JAX. I didn't implement it for the
TpuExecutable class in tpu_executable.h, since TF doesn't need this
functionality (yet?), but it shouldn't be too hard.

PiperOrigin-RevId: 330842613
Change-Id: I592068c7b1110e0ae32b241e3e6c5a7b121f3e0f
This commit is contained in:
Skye Wanderman-Milne 2020-09-09 18:34:34 -07:00 committed by TensorFlower Gardener
parent eb461280fe
commit d59bdf5493
7 changed files with 43 additions and 2 deletions

View File

@ -298,6 +298,7 @@ cc_library(
"//tensorflow/stream_executor/tpu:c_api_decl",
"//tensorflow/stream_executor/tpu:proto_helper",
"//tensorflow/stream_executor/tpu:status_helper",
"//tensorflow/stream_executor/tpu:tpu_executable_interface",
"//tensorflow/stream_executor/tpu:tpu_executor",
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
"@com_google_absl//absl/types:span",

View File

@ -137,6 +137,7 @@ tensorflow::Status SetExecutorStructFn(void* library_handle) {
TFTPU_SET_FN(executor_fn, TpuCompiler_Compile);
TFTPU_SET_FN(executor_fn, TpuCompiler_ShapeSize);
TFTPU_SET_FN(executor_fn, TpuExecutable_ExecuteAsyncOnStream);
TFTPU_SET_FN(executor_fn, TpuExecutable_Fingerprint);
TFTPU_SET_FN(executor_fn, TpuExecutable_Free);
TFTPU_SET_FN(executor_fn, XlaShapeToTpuShapeRepresentation);

View File

@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
#include "tensorflow/stream_executor/tpu/status_helper.h"
#include "tensorflow/stream_executor/tpu/tpu_executable_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_c_api.h"
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
@ -97,11 +98,11 @@ void XLA_HloModuleConfig_Free(XLA_HloModuleConfig* module_config) {
}
}
class TpuExecutable : public Executable {
class TpuExecutable : public TpuExecutableInterface {
public:
TpuExecutable(SE_Executable* se_executable,
std::shared_ptr<HloModule> hlo_module)
: Executable(std::move(hlo_module), nullptr, nullptr),
: TpuExecutableInterface(std::move(hlo_module), nullptr, nullptr),
se_executable_(se_executable) {}
~TpuExecutable() override {
@ -192,7 +193,31 @@ class TpuExecutable : public Executable {
return output;
}
absl::string_view fingerprint() const override {
const char* data;
size_t size;
ExecutorApiFn()->TpuExecutable_FingerprintFn(se_executable_, &data, &size);
return absl::string_view(data, size);
}
private:
Status LoadProgramAndEnqueueToStream(
const ServiceExecutableRunOptions& run_options,
absl::Span<const stream_executor::DeviceMemoryBase> arguments,
stream_executor::DeviceMemoryBase result,
absl::optional<stream_executor::DeviceMemoryBase>
cross_program_prefetch_addr) override {
LOG(FATAL) << "LoadProgramAndEnqueueToStream unimplemented";
}
Shape HostShapeToDeviceShape(const Shape& host_shape) override {
LOG(FATAL) << "HostShapeToDeviceShape unimplemented";
}
int64 ShapeSize(const Shape& shape) override {
LOG(FATAL) << "ShapeSize unimplemented";
}
SE_Executable* se_executable_;
};

View File

@ -113,4 +113,9 @@ int64 TpuExecutable::ShapeSize(const Shape& shape) {
return size;
}
absl::string_view TpuExecutable::fingerprint() const {
// TODO(skye): the fingerprint can be plumbed through via core_program_
LOG(FATAL) << "TpuExecutable::fingerprint() unimplemented";
}
} // namespace xla

View File

@ -46,6 +46,8 @@ class TpuExecutable : public TpuExecutableInterface {
const XLA_TpuProgram* core_program() const { return core_program_; }
absl::string_view fingerprint() const override;
private:
Status LoadProgramAndEnqueueToStream(
const ServiceExecutableRunOptions& run_options,

View File

@ -80,6 +80,8 @@ class TpuExecutableInterface : public Executable {
absl::optional<stream_executor::DeviceMemoryBase>
cross_program_prefetch_addr) = 0;
virtual absl::string_view fingerprint() const = 0;
protected:
virtual Shape HostShapeToDeviceShape(const Shape& host_shape) = 0;

View File

@ -300,6 +300,10 @@ TFTPU_CAPI_EXPORT void TpuExecutable_ExecuteAsyncOnStream(
SE_HloExecutionProfile* hlo_execution_profile, SE_ExecutionOutput* output,
SE_Status* status);
TFTPU_CAPI_EXPORT void TpuExecutable_Fingerprint(SE_Executable* executable,
const char** fingerprint,
size_t* size);
TFTPU_CAPI_EXPORT void TpuExecutable_Free(SE_Executable*);
// Converts an XLA `Shape` into its equivalent TPU `Shape` representation.
@ -445,6 +449,7 @@ struct TfTpu_ExecutorApiFn {
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_Compile);
TFTPU_ADD_FN_IN_STRUCT(TpuCompiler_ShapeSize);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_ExecuteAsyncOnStream);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Fingerprint);
TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_Free);
TFTPU_ADD_FN_IN_STRUCT(XlaShapeToTpuShapeRepresentation);