STT-tensorflow/tensorflow/compiler/xla/service/dynamic_index_splitter.cc
Michael Kuperstein e25754186b [XLA] Move more XlaBuilder users over to the new DS/DUS API
Move more tests and bridge op implementations over to the new
span-of-scalars API for DynamicSlice and DynamicUpdateSlice.

Also fix a bug in DynamicIndexSplitter to handle the case of an
already-empty index span.

PiperOrigin-RevId: 228952769
2019-01-11 15:14:06 -08:00

100 lines
4.0 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
#include <map>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
namespace xla {
StatusOr<bool> DynamicIndexSplitter::Run(HloModule* module) {
bool changed = false;
std::vector<HloComputation*> computations =
module->MakeNonfusionComputations();
for (HloComputation* computation : computations) {
for (HloInstruction* dynamic_op : computation->MakeInstructionPostOrder()) {
switch (dynamic_op->opcode()) {
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
break;
default:
continue;
}
auto parent = dynamic_op->parent();
bool is_update = dynamic_op->opcode() == HloOpcode::kDynamicUpdateSlice;
int64 num_indices = dynamic_op->operand(0)->shape().rank();
if (num_indices == 0) {
// If the operand rank is 0, directly replace R0 DS/DUS with the
// operand (for DS) or update (for DUS).
if (is_update) {
TF_CHECK_OK(parent->ReplaceInstruction(
dynamic_op, dynamic_op->mutable_operand(1)));
} else {
TF_CHECK_OK(parent->ReplaceInstruction(
dynamic_op, dynamic_op->mutable_operand(0)));
}
changed = true;
continue;
}
int64 index_operand_number = Cast<HloDynamicIndexInstruction>(dynamic_op)
->first_index_operand_number();
auto index_operand = dynamic_op->mutable_operand(index_operand_number);
if (ShapeUtil::IsScalar(index_operand->shape())) {
// This DS/DUS already uses scalar indices.
continue;
}
TF_RET_CHECK(index_operand->shape().rank() == 1);
auto index_element_type = index_operand->shape().element_type();
std::vector<HloInstruction*> index_array;
for (int64 dim = 0; dim < num_indices; ++dim) {
auto slice = parent->AddInstruction(HloInstruction::CreateSlice(
ShapeUtil::MakeShape(index_element_type, {1}), index_operand, {dim},
{dim + 1}, {1}));
auto bitcast = parent->AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(index_element_type, {}), slice));
index_array.push_back(bitcast);
}
auto new_dynamic_op =
is_update
? HloInstruction::CreateDynamicUpdateSlice(
dynamic_op->shape(), dynamic_op->mutable_operand(0),
dynamic_op->mutable_operand(1), absl::MakeSpan(index_array))
: HloInstruction::CreateDynamicSlice(
dynamic_op->shape(), dynamic_op->mutable_operand(0),
absl::MakeSpan(index_array),
dynamic_op->dynamic_slice_sizes());
TF_CHECK_OK(parent->ReplaceWithNewInstruction(dynamic_op,
std::move(new_dynamic_op)));
changed = true;
}
}
return changed;
}
} // namespace xla