[NFC] Eliminate references to HLO instr from Infeed/Outfeed thunks.

- Infeed thunk uses the instruction just for printing, eliminate that use.
- Outfeed think uses input shape, so capture that in an outfeed config object

PiperOrigin-RevId: 335653885
Change-Id: I6ee210ff8012e6a56b47fbabe36f7fb44b81dbd9
This commit is contained in:
Rahul Joshi 2020-10-06 09:17:47 -07:00 committed by TensorFlower Gardener
parent e014c2f458
commit f0702de8c1
5 changed files with 25 additions and 13 deletions

View File

@ -26,14 +26,13 @@ InfeedThunk::InfeedThunk(
ThunkInfo thunk_info,
const ShapeTree<BufferAllocation::Slice>& infeed_slices)
: Thunk(Kind::kInfeed, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
infeed_slices_(infeed_slices) {}
Status InfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& stream = *params.stream;
auto& buffer_allocations = *params.buffer_allocations;
VLOG(2) << "Infeeding to GPU: " << hlo_instruction_->ToString();
VLOG(2) << "Infeeding to GPU";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());

View File

@ -43,7 +43,6 @@ class InfeedThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
const ShapeTree<BufferAllocation::Slice> infeed_slices_;
};

View File

@ -14,26 +14,34 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/outfeed_manager.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
namespace gpu {
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info,
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr) {
OutfeedConfig config;
config.input_shape = instr->operand(0)->shape();
return config;
}
OutfeedThunk::OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
ShapeTree<BufferAllocation::Slice> outfeed_slices)
: Thunk(Kind::kOutfeed, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
config_(std::move(config)),
outfeed_slices_(std::move(outfeed_slices)) {}
Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
auto& stream = *params.stream;
auto& buffer_allocations = *params.buffer_allocations;
VLOG(2) << "Outfeeding from GPU: " << hlo_instruction_->ToString();
VLOG(2) << "Outfeeding from GPU";
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
@ -42,13 +50,12 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) {
outfeed_manager->BlockingGetNextDestination();
// Nothing to be done for empty tuples.
if (ShapeUtil::IsEmptyTuple(hlo_instruction_->operand(0)->shape())) {
if (ShapeUtil::IsEmptyTuple(config_.input_shape)) {
return Status::OK();
}
CHECK(ShapeUtil::Compatible(hlo_instruction_->operand(0)->shape(),
outfeed_buffers->shape()))
CHECK(ShapeUtil::Compatible(config_.input_shape, outfeed_buffers->shape()))
<< "XLA program outfeed request of shape "
<< hlo_instruction_->operand(0)->shape().ToString()
<< config_.input_shape.ToString()
<< " did not match the runtime's outfeed buffer of shape "
<< outfeed_buffers->shape().ToString();

View File

@ -25,6 +25,12 @@ limitations under the License.
namespace xla {
namespace gpu {
struct OutfeedConfig {
Shape input_shape;
};
OutfeedConfig GetOutfeedConfig(const HloInstruction* instr);
// A thunk that outfeeds data. Data must be already resident on the host. This
// thunk performs a host to device copy from the buffer allocated for the
// outfeed op to the host location.
@ -32,7 +38,7 @@ class OutfeedThunk : public Thunk {
public:
// Constructs a OutfeedThunk that copies data to the host-side
// outfeed queue from the buffers in the given shape tree.
OutfeedThunk(ThunkInfo thunk_info,
OutfeedThunk(ThunkInfo thunk_info, OutfeedConfig&& config,
ShapeTree<BufferAllocation::Slice> outfeed_slices);
OutfeedThunk(const OutfeedThunk&) = delete;
@ -41,7 +47,7 @@ class OutfeedThunk : public Thunk {
Status ExecuteOnStream(const ExecuteParams& params) override;
private:
const HloInstruction* hlo_instruction_;
OutfeedConfig config_;
const ShapeTree<BufferAllocation::Slice> outfeed_slices_;
};

View File

@ -193,8 +193,9 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildOutfeedThunk(
*slice = status_or_slice.ValueOrDie();
}
});
OutfeedConfig config = GetOutfeedConfig(inst);
return absl::make_unique<OutfeedThunk>(context_->GetThunkInfo(inst),
std::move(slices));
std::move(config), std::move(slices));
}
Status ThunkEmitter::HandleCustomCall(HloInstruction* custom_call) {