Previously, the scatter expander would crash when expanding an operation whose operands have shapes without layouts. Since the scatter expander runs before layout assignment, it should properly handle shapes without layouts. This CL uses a default layout for such shapes. Add a test case for the scatter expander. PiperOrigin-RevId: 256170096
72 lines
2.5 KiB
C++
72 lines
2.5 KiB
C++
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
==============================================================================*/
|
|
|
|
#include "tensorflow/compiler/xla/service/scatter_expander.h"
|
|
|
|
#include <memory>
|
|
#include <utility>
|
|
|
|
#include "tensorflow/compiler/xla/literal.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
|
#include "tensorflow/compiler/xla/shape_util.h"
|
|
#include "tensorflow/compiler/xla/test.h"
|
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
|
#include "tensorflow/compiler/xla/types.h"
|
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
|
|
|
namespace xla {
|
|
namespace {
|
|
|
|
// Test fixture for the ScatterExpander tests. It adds no members of its own
// and relies entirely on what HloTestBase provides.
class ScatterExpanderTest : public HloTestBase {};
|
|
|
|
// Regression test: runs ScatterExpander on a module whose scatter operand has
// had its layout cleared. The pass must complete without error and report that
// it changed the module.
TEST_F(ScatterExpanderTest, ScatterOperandWithoutLayout) {
  const char* kModuleStr = R"(
    HloModule scatter_expander

    scatter_computation {
      parameter0 = s32[] parameter(0)
      ROOT parameter1 = s32[] parameter(1)
    }

    ENTRY kernel_entry {
      operand = s32[5] iota(), iota_dimension=0
      indices = s32[1] parameter(0)
      update = s32[] constant(0)
      ROOT scatter = s32[5]{0} scatter(operand, indices, update),
          update_window_dims={}, inserted_window_dims={0},
          scatter_dims_to_operand_dims={0}, index_vector_dim=0,
          to_apply=scatter_computation
    })";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kModuleStr));

  // The HLO parser gives every shape in the input that lacks a layout a
  // default one, so the layout of the scatter operand is cleared explicitly
  // here to set up the case under test.
  HloInstruction* scatter_operand = FindInstruction(module.get(), "operand");
  // Guard the dereference below: abort the test with a clear failure rather
  // than crash if the "operand" instruction was not found.
  ASSERT_NE(scatter_operand, nullptr);
  scatter_operand->mutable_shape()->clear_layout();

  ScatterExpander scatter_expander;
  TF_ASSERT_OK_AND_ASSIGN(bool result,
                          RunHloPass(&scatter_expander, module.get()));
  // `result` is the pass's "changed" flag; the scatter must have been expanded.
  EXPECT_TRUE(result);
}
|
|
|
|
} // namespace
|
|
} // namespace xla
|