[XLA] Provide a way to test GPU buffer assignment, test buffer assignment for GEMM rewrite
PiperOrigin-RevId: 254607491
This commit is contained in:
parent
d4b6a19ee7
commit
eeda570f69
@ -89,6 +89,10 @@ class GpuExecutable : public Executable {
|
|||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
absl::Span<const ShapedBuffer* const> arguments) override;
|
absl::Span<const ShapedBuffer* const> arguments) override;
|
||||||
|
|
||||||
|
std::shared_ptr<const BufferAssignment> GetBufferAssignment() const {
|
||||||
|
return assignment_;
|
||||||
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
StatusOr<ScopedShapedBuffer> Execute(
|
StatusOr<ScopedShapedBuffer> Execute(
|
||||||
const ServiceExecutableRunOptions* run_options,
|
const ServiceExecutableRunOptions* run_options,
|
||||||
|
@ -58,6 +58,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":gpu_codegen_test",
|
":gpu_codegen_test",
|
||||||
"//tensorflow/compiler/xla:debug_options_flags",
|
"//tensorflow/compiler/xla:debug_options_flags",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla/service:gpu_plugin",
|
"//tensorflow/compiler/xla/service:gpu_plugin",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
|
Loading…
x
Reference in New Issue
Block a user