Rollback of [XLA] Allow hlo_rematerialization pass to look through bitcasts when doing rematerialization.

Causes hours-long compilation times on OOMing code using JAX PRNG.

PiperOrigin-RevId: 345611840
Change-Id: I2968f4ffc733c27ae2c37e032954a7d9f93cd818
This commit is contained in:
Marcello Maggioni 2020-12-03 22:58:23 -08:00 committed by TensorFlower Gardener
parent fedf3f45fb
commit 2797009193
3 changed files with 259 additions and 57 deletions

View File

@ -3776,6 +3776,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",

View File

@ -16,10 +16,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_rematerialization.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <set>
#include <string>
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
@ -153,7 +155,22 @@ struct Item {
int64 position;
};
// Data structure that records a user of the buffer defined by an Item.
// It also records the operand_number through which the use occurs, so that
// indirect uses can be better identified (like for example a buffer used
// through a bitcast).
struct ItemUse {
  // The instruction (wrapped in its Item) that consumes the buffer.
  Item* user;
  // Index of the operand of 'user' through which the buffer is consumed.
  int64 operand_number;
  ItemUse(Item* user, int64 op_num) : user(user), operand_number(op_num) {}
  // Two uses are equal only if both the user and the operand slot match;
  // the same instruction may use a buffer through several operands.
  bool operator==(const ItemUse& other) const {
    return user == other.user && operand_number == other.operand_number;
  }
};
using ItemList = absl::InlinedVector<Item*, 3>;
using UsesList = absl::InlinedVector<ItemUse, 3>;
// Class which maintains an ordered list of instructions with fast insertion
// before arbitrary elements.
@ -412,11 +429,11 @@ class InstructionList {
// has_indirect_users to whether any of the uses is indirect. A use is indirect
// if the instruction defining logical_buffer is not an operand of the use. This
// can happen via buffer aliasing (eg, tuples).
ItemList GetUsers(const InstructionList& instruction_list,
UsesList GetUsers(const InstructionList& instruction_list,
const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis,
bool* has_indirect_users) {
ItemList users;
UsesList users;
// To identify uses iterate through all HloInstruction users of the
// BufferAliases of the logical buffer.
*has_indirect_users = false;
@ -431,14 +448,18 @@ ItemList GetUsers(const InstructionList& instruction_list,
// instruction (the GTE instruction only uses the pointer vector).
continue;
}
if (buffer_alias.instruction() != logical_buffer->instruction()) {
if (buffer_alias.instruction() != logical_buffer->instruction() &&
buffer_alias.instruction()->opcode() != HloOpcode::kBitcast) {
*has_indirect_users = true;
}
// A buffer may be used by the instruction via more than one alias. For
// example, a buffer which appears in more than one element of a tuple.
Item* user_item = instruction_list.GetItem(user);
if (!absl::c_linear_search(users, user_item)) {
users.push_back(user_item);
for (int64 op_idx : user->OperandIndices(buffer_alias.instruction())) {
if (!absl::c_linear_search(
users, ItemUse{user_item, static_cast<int>(op_idx)})) {
users.push_back(ItemUse{user_item, static_cast<int>(op_idx)});
}
}
}
}
@ -516,7 +537,8 @@ class MemoryUsageTracker {
// is remat_item. This method should be called after the HLO graph has
// been transformed (rematerialization instruction created and connected
// to uses).
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
Status AddRematerializedInstruction(Item* original_item, Item* remat_item,
absl::Span<Item*> bitcasts);
// Selects and returns the best candidate instructions for rematerialization.
// A sequence of candidate instructions of length between min_block_size and
@ -538,6 +560,9 @@ class MemoryUsageTracker {
// Returns whether 'item' has any unplaced users.
bool HasUnplacedUsers(Item* item) const;
// Returns the list of uses for a specific 'item'.
const UsesList GetItemUses(Item* item) const;
// Returns whether 'item' is currently in progress.
bool IsInProgressItem(Item* item) const { return item == in_progress_item_; }
@ -588,7 +613,7 @@ class MemoryUsageTracker {
bool has_indirect_uses;
// The instructions which use this buffer.
ItemList users;
UsesList users;
// The number of users (HloInstructions) of this buffer which have not yet
// been placed in the sequence.
@ -611,7 +636,7 @@ class MemoryUsageTracker {
const LogicalBuffer* logical_buffer,
const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
bool has_indirect_uses = false;
ItemList users = GetUsers(instruction_list_, logical_buffer,
UsesList users = GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &has_indirect_uses);
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
logical_buffer->shape(), std::move(users), live_out,
@ -621,13 +646,13 @@ class MemoryUsageTracker {
// Create a new buffer representing a rematerialization of given buffer for
// the given uses.
Buffer& RematerializeBuffer(const Buffer& original_buffer, Item* remat_item,
ItemList&& rematerialized_uses) {
UsesList&& rematerialized_uses) {
CHECK(original_buffer.defining_instruction->placed)
<< original_buffer.defining_instruction->instruction->name();
CHECK(!original_buffer.has_indirect_uses) << original_buffer.ToString();
CHECK(!original_buffer.live_out) << original_buffer.ToString();
for (Item* use : rematerialized_uses) {
CHECK(!use->placed) << use->instruction->name();
for (ItemUse& use : rematerialized_uses) {
CHECK(!use.user->placed) << use.user->instruction->name();
}
return NewBuffer(remat_item, original_buffer.shape,
std::move(rematerialized_uses), /*live_out=*/false,
@ -665,8 +690,6 @@ class MemoryUsageTracker {
return absl::c_linear_search(in_progress_uses, buffer_id);
}
// Returns whether the given buffer is live at the current program
// point.
bool IsCurrentlyLive(BufferId buffer_id) const {
const Buffer& buffer = buffers_[buffer_id];
return (buffer.defining_instruction->placed &&
@ -692,11 +715,18 @@ class MemoryUsageTracker {
// Create a new buffer, add it to buffers_, and return a reference.
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
ItemList&& users, bool live_out, bool has_indirect_uses) {
UsesList&& uses, bool live_out, bool has_indirect_uses) {
int buffer_id = buffers_.size();
auto get_num_of_unique_users = [](const UsesList& uses) -> int64 {
absl::flat_hash_set<Item*> users_set;
for (const ItemUse& use : uses) {
users_set.insert(use.user);
}
return users_set.size();
};
buffers_.push_back(Buffer{
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
has_indirect_uses, users, static_cast<int64>(users.size())});
has_indirect_uses, uses, get_num_of_unique_users(uses)});
return buffers_.back();
}
@ -771,12 +801,15 @@ MemoryUsageTracker::MemoryUsageTracker(
// Add users of while to Buffer users.
bool unused;
for (Item* user_item : GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &unused)) {
if (!absl::c_linear_search(buffer->users, user_item)) {
buffer->users.push_back(user_item);
for (ItemUse& user_item : GetUsers(instruction_list_, logical_buffer,
points_to_analysis, &unused)) {
auto existing_user_it = absl::c_find_if(
buffer->users,
[&](const ItemUse& use) { return user_item.user == use.user; });
if (existing_user_it == buffer->users.end()) {
buffer->unfinished_user_count++;
user_item->buffers_used.push_back(buffer->id);
user_item.user->buffers_used.push_back(buffer->id);
buffer->users.push_back(user_item);
}
}
} else {
@ -784,8 +817,10 @@ MemoryUsageTracker::MemoryUsageTracker(
logical_buffer, points_to_analysis,
ContainsKey(live_out_set, logical_buffer));
item->buffers_defined.push_back(buffer->id);
for (Item* user : buffer->users) {
user->buffers_used.push_back(buffer->id);
for (ItemUse& user : buffer->users) {
if (!absl::c_linear_search(user.user->buffers_used, buffer->id)) {
user.user->buffers_used.push_back(buffer->id);
}
}
}
@ -1003,14 +1038,14 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
// Compressed buffer is now alive.
memory_usage_ += size_function_(compressed_item->instruction->shape());
ItemList placed_users;
ItemList unplaced_users;
UsesList placed_users;
UsesList unplaced_users;
CHECK_EQ(original_item->buffers_output.size(), 1);
BufferId original_buffer_id = original_item->buffers_output[0];
Buffer& original_buffer = buffers_.at(original_buffer_id);
for (Item* user : original_buffer.users) {
if (user->placed) {
CHECK(IsFinished(user)) << user->instruction->name();
for (ItemUse& user : original_buffer.users) {
if (user.user->placed) {
CHECK(IsFinished(user.user)) << user.user->instruction->name();
placed_users.push_back(user);
} else {
unplaced_users.push_back(user);
@ -1018,10 +1053,10 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
}
original_buffer.users = std::move(placed_users);
original_buffer.unfinished_user_count = 0;
original_buffer.users.push_back(compressed_item);
original_buffer.users.push_back(ItemUse{compressed_item, 0});
Buffer& compressed_buffer =
NewBuffer(compressed_item, compressed_item->instruction->shape(),
{uncompressed_item}, /*live_out=*/false,
{ItemUse{uncompressed_item, 0}}, /*live_out=*/false,
/*has_indirect_uses=*/false);
compressed_item->buffers_used = original_item->buffers_output;
compressed_item->buffers_output = {compressed_buffer.id};
@ -1036,8 +1071,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
uncompressed_item->buffers_output = {uncompressed_buffer.id};
uncompressed_item->buffers_defined = {uncompressed_buffer.id};
for (Item* user : uncompressed_buffer.users) {
BufferIdList& buffers_used = user->buffers_used;
for (ItemUse& user : uncompressed_buffer.users) {
BufferIdList& buffers_used = user.user->buffers_used;
std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
uncompressed_buffer.id);
}
@ -1045,8 +1080,8 @@ Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
return Status::OK();
}
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
Item* remat_item) {
Status MemoryUsageTracker::AddRematerializedInstruction(
Item* original_item, Item* remat_item, absl::Span<Item*> bitcasts) {
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
<< original_item->instruction->name()
<< ", remat_instruction = " << remat_item->instruction->name();
@ -1067,9 +1102,23 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
// Buffer used by this instruction was dead, now is alive.
memory_usage_ += AllocatedSize(buffer.id);
}
buffer.unfinished_user_count++;
buffer.users.push_back(remat_item);
absl::InlinedVector<ItemUse, 2> filtered_users;
std::copy_if(buffer.users.begin(), buffer.users.end(),
std::back_inserter(filtered_users),
[&](const ItemUse& iu) { return iu.user == original_item; });
for (ItemUse& u : filtered_users) {
buffer.users.push_back(ItemUse{remat_item, u.operand_number});
}
}
for (Item* bitcast : bitcasts) {
CHECK_EQ(bitcast->instruction->opcode(), HloOpcode::kBitcast);
for (BufferId buffer_id : bitcast->buffers_used) {
Buffer& buffer = buffers_.at(buffer_id);
buffer.unfinished_user_count++;
buffer.users.push_back(ItemUse{bitcast, 0});
}
}
// Create a new set of Buffers defined by the new rematerialization
@ -1078,10 +1127,10 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
for (BufferId old_buffer_id : original_item->buffers_defined) {
Buffer& old_buffer = buffers_.at(old_buffer_id);
ItemList placed_users;
ItemList unplaced_users;
for (Item* user : old_buffer.users) {
if (user->placed) {
UsesList placed_users;
UsesList unplaced_users;
for (ItemUse& user : old_buffer.users) {
if (user.user->placed) {
placed_users.push_back(user);
} else {
unplaced_users.push_back(user);
@ -1097,8 +1146,8 @@ Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
RematerializeBuffer(old_buffer, remat_item, std::move(unplaced_users));
remat_item->buffers_defined.push_back(new_buffer.id);
for (Item* user : new_buffer.users) {
BufferIdList& buffers_used = user->buffers_used;
for (ItemUse& user : new_buffer.users) {
BufferIdList& buffers_used = user.user->buffers_used;
std::replace(buffers_used.begin(), buffers_used.end(), old_buffer_id,
new_buffer.id);
}
@ -1131,6 +1180,10 @@ string MemoryUsageTracker::ToString() const {
absl::StrAppend(&output, " ", buffer.ToString(), live, ", ",
buffer.unfinished_user_count, " unfinished uses\n");
}
absl::StrAppend(&output, " Outputs:\n");
for (BufferId buffer_id : item->buffers_output) {
absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
}
absl::StrAppend(&output, " Uses:\n");
for (BufferId buffer_id : item->buffers_used) {
absl::StrAppend(&output, " ", buffers_[buffer_id].ToString(), "\n");
@ -1190,12 +1243,14 @@ bool MemoryUsageTracker::Check() const {
}
for (const Buffer& buffer : buffers_) {
int64 unfinished_uses = 0;
for (Item* user : buffer.users) {
const BufferIdList& used_buffers = user->buffers_used;
absl::flat_hash_set<Item*> already_counted_user;
for (const ItemUse& user : buffer.users) {
const BufferIdList& used_buffers = user.user->buffers_used;
CHECK(absl::c_linear_search(used_buffers, buffer.id))
<< "Instruction " << user->instruction->name()
<< "Instruction " << user.user->instruction->name()
<< " used buffers is missing " << buffer.ToString();
if (!IsFinished(user)) {
if (!IsFinished(user.user) &&
already_counted_user.insert(user.user).second) {
unfinished_uses++;
}
}
@ -1397,8 +1452,8 @@ MemoryUsageTracker::PickRematerializationCandidates(
bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
for (BufferId buffer_id : item->buffers_defined) {
const Buffer& buffer = buffers_.at(buffer_id);
for (Item* user : buffer.users) {
if (!user->placed) {
for (const ItemUse& user : buffer.users) {
if (!user.user->placed) {
return true;
}
}
@ -1406,6 +1461,17 @@ bool MemoryUsageTracker::HasUnplacedUsers(Item* item) const {
return false;
}
// Returns the concatenated uses of every buffer defined by 'item'. The
// result may contain the same user more than once when several defined
// buffers share that user.
const UsesList MemoryUsageTracker::GetItemUses(Item* item) const {
  UsesList combined_users;
  for (BufferId buffer_id : item->buffers_defined) {
    const Buffer& buffer = buffers_.at(buffer_id);
    // Append all recorded uses of this buffer in one shot.
    combined_users.insert(combined_users.end(), buffer.users.begin(),
                          buffer.users.end());
  }
  return combined_users;
}
StatusOr<int64> RematerializeInstructions(
MemoryUsageTracker* memory_tracker, std::vector<Item*>* best_items,
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
@ -1443,18 +1509,30 @@ StatusOr<int64> RematerializeInstructions(
Item* remat_item = instruction_list->CreateItem(remat);
// Replace each remaining use of 'best' with the rematerialization.
std::vector<HloInstruction*> best_users_copy = best->users();
for (HloInstruction* user : best_users_copy) {
if (!memory_tracker->IsPlaced(user)) {
absl::InlinedVector<Item*, 4> bitcasts;
for (auto& user : memory_tracker->GetItemUses(best_item)) {
if (!memory_tracker->IsPlaced(user.user->instruction)) {
VLOG(2) << " Replacing use of " << best->name() << " in "
<< user->name() << " with " << remat->name();
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
<< user.user->instruction->name() << " with " << remat->name();
const int64 op_idx = user.operand_number;
auto* remat_use = remat;
if (user.user->instruction->operand(op_idx)->shape() !=
remat->shape()) {
remat_use = computation->AddInstruction(HloInstruction::CreateUnary(
user.user->instruction->operand(op_idx)->shape(),
HloOpcode::kBitcast, remat));
bitcasts.push_back(instruction_list->CreateItem(remat_use));
bitcasts.back()->buffers_output = remat_item->buffers_defined;
bitcasts.back()->buffers_used = remat_item->buffers_defined;
}
TF_RETURN_IF_ERROR(
user.user->instruction->ReplaceOperandWith(op_idx, remat_use));
}
}
// Account for the rematerialization in the memory tracker.
TF_RETURN_IF_ERROR(
memory_tracker->AddRematerializedInstruction(best_item, remat_item));
TF_RETURN_IF_ERROR(memory_tracker->AddRematerializedInstruction(
best_item, remat_item, absl::MakeSpan(bitcasts)));
// Insert rematerialized instruction right before the earliest unplaced
// use of the instruction *and* the earliest unplaced last use of any
@ -1463,7 +1541,14 @@ StatusOr<int64> RematerializeInstructions(
// this could increase memory usage.
ItemList place_before;
for (auto user : remat->users()) {
place_before.push_back(instruction_list->GetItem(user));
if (!absl::c_linear_search(bitcasts, instruction_list->GetItem(user))) {
place_before.push_back(instruction_list->GetItem(user));
}
}
for (auto* bitcast : bitcasts) {
for (auto user : bitcast->instruction->users()) {
place_before.push_back(instruction_list->GetItem(user));
}
}
for (auto* operand : remat->operands()) {
for (auto* operand_user : operand->users()) {
@ -1486,12 +1571,25 @@ StatusOr<int64> RematerializeInstructions(
}
instruction_list->InsertBeforeInstructions(remat_item, place_before);
for (auto* bitcast : bitcasts) {
instruction_list->InsertBeforeInstructions(bitcast, place_before);
}
// Helper function that looks through bitcasts when determining if there
// is an active user for an HloInstruction.
std::function<bool(HloInstruction*)> uses_empty = [&](HloInstruction* i) {
for (auto* u : i->users()) {
if (u->opcode() != HloOpcode::kBitcast || !uses_empty(u)) {
return false;
}
}
return true;
};
// If the rematerialized instruction is dead then rematerialization is
// essentially a move. Don't delete the instruction now because we don't
// want duplicate HloInstruction* values during the course of the
// transformation because we keep maps with HloInstruction* values as
// keys.
if (best->users().empty()) {
if (uses_empty(best)) {
VLOG(2) << best->name() << " is now dead";
if (ContainsKey(*remat_move_instructions, best)) {
// Previously, 'best' was a rematerialization which killed the
@ -1501,8 +1599,12 @@ StatusOr<int64> RematerializeInstructions(
instruction_list->Denylist(remat);
}
remat_move_instructions->insert(remat);
net_instructions_added += bitcasts.size();
} else {
net_instructions_added++;
net_instructions_added += bitcasts.size() + 1;
}
for (auto* bitcast : bitcasts) {
instruction_list->Denylist(bitcast->instruction);
}
}
VLOG(1) << "Rematerializing instructions ["

View File

@ -748,6 +748,105 @@ ENTRY %entry {
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
}
// Test rematerialization of values through bitcasts.
// It's expected that the broadcast gets rematerialized.
TEST_F(HloRematerializationTest, ThroughBitcastRemat) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
ENTRY %mycomp (param: f32[1]) -> f32[1] {
%param = f32[1]{0} parameter(0)
%reshape = f32[] reshape(f32[1]{0} %param)
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
%negate = f32[1024,1]{1,0} negate(f32[1024,1]{1,0} %broadcast)
%concatenate = f32[2048,1]{1,0} concatenate(f32[1024,1]{1,0} %negate, f32[1024,1]{1,0} %negate), dimensions={0}
%slice = f32[1,1]{1,0} slice(f32[2048,1]{1,0} %concatenate), slice={[0:1], [0:1]}
%bitcast.1 = f32[1]{0} bitcast(f32[1,1]{1,0} %slice)
%concatenate.1 = f32[1025]{0} concatenate(f32[1024]{0} %bitcast, f32[1]{0} %bitcast.1), dimensions={0}
ROOT %slice.1 = f32[1]{0} slice(f32[1025]{0} %concatenate.1), slice={[0:1]}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto* computation = module->entry_computation();
  // Find and save the original broadcast instruction, which should be
  // rematerialized (reached here through the bitcast feeding the concat).
  const HloInstruction* slice = computation->root_instruction();
  ASSERT_THAT(slice,
              op::Slice(op::Concatenate(op::Bitcast(op::Broadcast(_)), _)));
  const HloInstruction* concat = slice->operand(0);
  const HloInstruction* bcast = concat->operand(0)->operand(0);
  // Computation requires 16KB without rematerialization, but uses only 12KB
  // with rematerialization so pick a memory limit between these values (14KB).
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/14 * 1024, module.get()));
  EXPECT_TRUE(changed);
  // Root should not have changed.
  EXPECT_EQ(computation->root_instruction(), slice);
  // The bitcast for the rematerialized broadcast.
  const HloInstruction* remat_bitcast = concat->operand(0);
  // The broadcast should have been rematerialized: the operand of the
  // bitcast must be a broadcast distinct from the original one.
  const HloInstruction* remat_broadcast = remat_bitcast->operand(0);
  EXPECT_THAT(remat_broadcast, op::Broadcast(::testing::Ne(bcast)));
  // The rematerialized broadcast should be immediately before its bitcast
  // and the bitcast before the concatenate in the sequence.
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 2],
            concat);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 3],
            remat_bitcast);
  EXPECT_EQ(module->schedule()
                .sequence(computation)
                .instructions()[computation->instruction_count() - 4],
            remat_broadcast);
}
// Test that the "deny list for move remats" engages when we rematerialize
// through bitcasts.
TEST_F(HloRematerializationTest, ThroughBitcastRematInfiniteLoop) {
  const string& hlo_string = R"(
HloModule fusion, is_scheduled=true
ENTRY %mycomp (param: f32[1]) -> f32[1024] {
%param = f32[1]{0} parameter(0)
%reshape = f32[] reshape(f32[1]{0} %param)
%broadcast = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast)
%broadcast2 = f32[1024,1]{1,0} broadcast(f32[] %reshape), dimensions={}
%bitcast2 = f32[1024]{0} bitcast(f32[1024,1]{1,0} %broadcast2)
ROOT %add = f32[1024]{0} add(f32[1024]{0} %bitcast, f32[1024]{0} %bitcast2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto* computation = module->entry_computation();
  // Save the root add so we can verify below that both broadcasts are still
  // reachable through their bitcasts once the pass has run.
  const HloInstruction* add = computation->root_instruction();
  // Run with a low rematerialization limit that cannot be satisfied to make
  // sure that we don't get stuck in a loop trying to lower it.
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          RunHloRematerialization(
                              /*memory_limit_bytes=*/1024, module.get()));
  ASSERT_THAT(add, op::Add(op::Bitcast(op::Broadcast(_)),
                           op::Bitcast(op::Broadcast(_))));
  EXPECT_TRUE(changed);
}
} // namespace
} // namespace xla