[NFC] Eliminate references to HLO instr from Infeed/Outfeed thunks.
- Infeed thunk uses the instruction just for printing, eliminate that use. - Outfeed think uses input shape, so capture that in an outfeed config object PiperOrigin-RevId: 335653885 Change-Id: I6ee210ff8012e6a56b47fbabe36f7fb44b81dbd9
This commit is contained in:
parent
e014c2f458
commit
f0702de8c1
@ -26,14 +26,13 @@ InfeedThunk::InfeedThunk(
|
||||
ThunkInfo thunk_info,
|
||||
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
|
||||
: Thunk(Kind::kInfeed, thunk_info),
|
||||
hlo_instruction_(thunk_info.hlo_instruction),
|
||||
infeed_slices_(infeed_slices) {}
|
||||
|
||||
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto& stream = *params.stream;
|
||||
auto& buffer_allocations = *params.buffer_allocations;
|
||||
|
||||
VLOG(2) << "Infeeding to GPU: " << hlo_instruction_->ToString();
|
||||
VLOG(2) << "Infeeding to GPU";
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
|
||||
@ -43,7 +43,6 @@ class InfeedThunk : public Thunk {
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
private:
|
||||
const HloInstruction* hlo_instruction_;
|
||||
const ShapeTree<BufferAllocation::Slice> infeed_slices_;
|
||||
};
|
||||
|
||||
|
||||
@ -14,26 +14,34 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
|
||||
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
|
||||
OutfeedConfig config;
|
||||
config.input_shape = instr->operand(0)->shape();
|
||||
return config;
|
||||
}
|
||||
|
||||
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices)
|
||||
: Thunk(Kind::kOutfeed, thunk_info),
|
||||
hlo_instruction_(thunk_info.hlo_instruction),
|
||||
config_(std::move(config)),
|
||||
outfeed_slices_(std::move(outfeed_slices)) {}
|
||||
|
||||
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
auto& stream = *params.stream;
|
||||
auto& buffer_allocations = *params.buffer_allocations;
|
||||
|
||||
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction_->ToString();
|
||||
VLOG(2) << "Outfeeding from GPU";
|
||||
|
||||
auto op_profiler =
|
||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||
@ -42,13 +50,12 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
|
||||
outfeed_manager->BlockingGetNextDestination();
|
||||
|
||||
// Nothing to be done for empty tuples.
|
||||
if (ShapeUtil::IsEmptyTuple(hlo_instruction_->operand(0)->shape())) {
|
||||
if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
|
||||
return Status::OK();
|
||||
}
|
||||
CHECK(ShapeUtil::Compatible(hlo_instruction_->operand(0)->shape(),
|
||||
outfeed_buffers->shape()))
|
||||
CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
|
||||
<< "XLA program outfeed request of shape "
|
||||
<< hlo_instruction_->operand(0)->shape().ToString()
|
||||
<< config_.input_shape.ToString()
|
||||
<< " did not match the runtime's outfeed buffer of shape "
|
||||
<< outfeed_buffers->shape().ToString();
|
||||
|
||||
|
||||
@ -25,6 +25,12 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
struct OutfeedConfig {
|
||||
Shape input_shape;
|
||||
};
|
||||
|
||||
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
|
||||
|
||||
// A thunk that outfeeds data. Data must be already resident on the host. This
|
||||
// thunk performs a host to device copy from the buffer allocated for the
|
||||
// outfeed op to the host location.
|
||||
@ -32,7 +38,7 @@ class OutfeedThunk : public Thunk {
|
||||
public:
|
||||
// Constructs a OutfeedThunk that copies data to the host-side
|
||||
// outfeed queue from the buffers in the given shape tree.
|
||||
OutfeedThunk(ThunkInfo thunk_info,
|
||||
OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
|
||||
ShapeTree<BufferAllocation::Slice> outfeed_slices);
|
||||
|
||||
OutfeedThunk(const OutfeedThunk&) = delete;
|
||||
@ -41,7 +47,7 @@ class OutfeedThunk : public Thunk {
|
||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||
|
||||
private:
|
||||
const HloInstruction* hlo_instruction_;
|
||||
OutfeedConfig config_;
|
||||
const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
|
||||
};
|
||||
|
||||
|
||||
@ -193,8 +193,9 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
|
||||
*slice = status_or_slice.ValueOrDie();
|
||||
}
|
||||
});
|
||||
OutfeedConfig config = GetOutfeedConfig(inst);
|
||||
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
|
||||
std::move(slices));
|
||||
std::move(config), std::move(slices));
|
||||
}
|
||||
|
||||
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user