[XLA] Add XLA builder support for the variadic reduce window op. This is the first CL leading to full support of variadic reduce window.
PiperOrigin-RevId: 338417352 Change-Id: I8e5907f0ddf2a29081c4d84d593b30f5c3eda6ed
parent abfdb66c8e
commit 96bd0a0b68
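Before the diff, a minimal sketch (not part of the commit) of how the new variadic client overload might be invoked. It mirrors the VariadicDynamicReduceWindow test added further down; the helper name VariadicReduceWindowExample and the reducer name sum2 are invented for illustration.

#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

xla::XlaOp VariadicReduceWindowExample(xla::XlaBuilder* b, xla::XlaOp in0,
                                       xla::XlaOp in1) {
  // Reducer over two streams: parameters are [accum0, accum1, value0, value1]
  // and the root is a 2-tuple, matching the convention this CL documents.
  xla::XlaBuilder bsum("sum2");
  auto scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  auto a0 = xla::Parameter(&bsum, 0, scalar, "a0");
  auto a1 = xla::Parameter(&bsum, 1, scalar, "a1");
  auto v0 = xla::Parameter(&bsum, 2, scalar, "v0");
  auto v1 = xla::Parameter(&bsum, 3, scalar, "v1");
  xla::Tuple(&bsum, {xla::Add(a0, v0), xla::Add(a1, v1)});
  xla::XlaComputation sum2 = bsum.Build().ConsumeValueOrDie();

  // The new overload takes spans of operands and init values instead of a
  // single operand/init pair, and yields a tuple-shaped result.
  auto init = xla::ConstantR0<float>(b, 0.f);
  return xla::ReduceWindow({in0, in1}, {init, init}, sum2,
                           /*window_dimensions=*/{1, 2, 4},
                           /*window_strides=*/{1, 1, 1},
                           xla::Padding::kValid);
}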
@@ -2324,31 +2324,53 @@ XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_RETURN_IF_ERROR(
        ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                              window_dimensions, window_strides));
    return ReduceWindow(absl::MakeSpan(&operand, 1),
                        absl::MakeSpan(&init_value, 1), computation,
                        window_dimensions, window_strides, padding);
  });
}

XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
                               absl::Span<const XlaOp> init_values,
                               const XlaComputation& computation,
                               absl::Span<const int64> window_dimensions,
                               absl::Span<const int64> window_strides,
                               Padding padding) {
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    const Shape* operand_shape = nullptr;
    for (const auto& operand : operands) {
      TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
      TF_RETURN_IF_ERROR(
          ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
                                window_dimensions, window_strides));
    }
    CHECK(operand_shape != nullptr);
    std::vector<std::pair<int64, int64>> padding_values =
        MakePadding(AsInt64Slice(operand_shape->dimensions()),
                    window_dimensions, window_strides, padding);
    return ReduceWindowWithGeneralPadding(
        operand, init_value, computation, window_dimensions, window_strides,
        operands, init_values, computation, window_dimensions, window_strides,
        /*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
  });
}

XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
    XlaOp operand, XlaOp init_value, const XlaComputation& computation,
    absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
    const XlaComputation& computation,
    absl::Span<const int64> window_dimensions,
    absl::Span<const int64> window_strides,
    absl::Span<const int64> base_dilations,
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<const Shape*> operand_shapes, init_shapes;
  return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
    TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
    for (int i = 0; i < operands.size(); ++i) {
      const auto& operand = operands[i];
      const auto& init_value = init_values[i];
      TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
      operand_shapes.push_back(operand_shape);
      TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
      init_shapes.push_back(init_shape);
    }
    TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(auto window,
@@ -2358,12 +2380,33 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
                        /*rhs_dilation=*/window_dilations));
    TF_ASSIGN_OR_RETURN(
        Shape shape, ShapeInference::InferReduceWindowShape(
                         *operand_shape, *init_shape, window, to_apply_shape));
    return ReduceWindowInternal(shape, operand, init_value, computation,
                         absl::MakeSpan(operand_shapes),
                         absl::MakeSpan(init_shapes), window, to_apply_shape));
    return ReduceWindowInternal(shape, operands, init_values, computation,
                                std::move(window));
  });
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, absl::Span<const XlaOp> operands,
    absl::Span<const XlaOp> init_values, const XlaComputation& computation,
    Window window) {
  if (operands.size() == 1) {
    return ReduceWindowInternal(shape, operands[0], init_values[0], computation,
                                window);
  } else {
    HloInstructionProto instr;
    *instr.mutable_shape() = shape.ToProto();
    *instr.mutable_window() = std::move(window);
    AddCalledComputation(computation, &instr);
    std::vector<XlaOp> args;
    args.insert(args.end(), operands.begin(), operands.end());
    args.insert(args.end(), init_values.begin(), init_values.end());
    return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
                          absl::MakeSpan(args));
  }
}

StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
    const Shape& shape, XlaOp operand, XlaOp init_value,
    const XlaComputation& computation, Window window) {
@@ -4067,6 +4110,17 @@ XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
                   padding);
}

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding) {
  CHECK(!operands.empty());
  return operands[0].builder()->ReduceWindow(operands, init_values, computation,
                                             window_dimensions, window_strides,
                                             padding);
}

XlaOp ReduceWindowWithGeneralPadding(
    const XlaOp operand, const XlaOp init_value,
    const XlaComputation& computation,
@@ -4076,8 +4130,9 @@ XlaOp ReduceWindowWithGeneralPadding(
    absl::Span<const int64> window_dilations,
    absl::Span<const std::pair<int64, int64>> padding) {
  return operand.builder()->ReduceWindowWithGeneralPadding(
      operand, init_value, computation, window_dimensions, window_strides,
      base_dilations, window_dilations, padding);
      absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
      window_dimensions, window_strides, base_dilations, window_dilations,
      padding);
}

XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,
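In the builder changes above, the single-operand entry points forward into the variadic path by viewing one value as a one-element span, so only one implementation validates padding and builds the instruction. The following standalone snippet is an assumption-level illustration of that absl::MakeSpan(&x, 1) idiom, not code from the commit.

#include "absl/types/span.h"

int SumAll(absl::Span<const int> values) {
  int total = 0;
  for (int v : values) total += v;
  return total;
}

int SumOne() {
  int only = 7;
  // A pointer plus a length of one views a single object as a span, which is
  // how the XlaOp overloads above reuse the span-based code path.
  return SumAll(absl::MakeSpan(&only, 1));
}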
@@ -648,18 +648,28 @@ class XlaBuilder {
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                     absl::Span<const XlaOp> init_values,
                     const XlaComputation& computation,
                     absl::Span<const int64> window_dimensions,
                     absl::Span<const int64> window_strides, Padding padding);

  XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
      const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
      absl::Span<const int64> window_strides,
      absl::Span<const int64> base_dilations,
      absl::Span<const int64> window_dilations,
      absl::Span<const std::pair<int64, int64>> padding);

  StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape,
                                       absl::Span<const XlaOp> operands,
                                       absl::Span<const XlaOp> init_values,
                                       const XlaComputation& computation,
                                       Window window);
  virtual StatusOr<XlaOp> ReduceWindowInternal(
      const Shape& shape, XlaOp operand, XlaOp init_value,
      const XlaComputation& computation, Window window);

  XlaOp CrossReplicaSum(XlaOp operand,
                        absl::Span<const ReplicaGroup> replica_groups = {});

@@ -1137,6 +1147,12 @@ class XlaBuilder {
                             absl::Span<const int64> window_dimensions,
                             absl::Span<const int64> window_strides,
                             Padding padding);
  friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                            absl::Span<const XlaOp> init_values,
                            const XlaComputation& computation,
                            absl::Span<const int64> window_dimensions,
                            absl::Span<const int64> window_strides,
                            Padding padding);
  friend XlaOp ReduceWindowWithGeneralPadding(
      XlaOp operand, XlaOp init_value, const XlaComputation& computation,
      absl::Span<const int64> window_dimensions,
@@ -1965,6 +1981,12 @@ XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
                   absl::Span<const XlaOp> init_values,
                   const XlaComputation& computation,
                   absl::Span<const int64> window_dimensions,
                   absl::Span<const int64> window_strides, Padding padding);

// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(
@@ -873,6 +873,8 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) {
  ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
          << "\n";
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(
@@ -880,6 +882,46 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) {
      << result_shape;
}

TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
       ShapeUtil::MakeShape(U32, {})});
  auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
  auto p1 = Parameter(&b, 1, tuple_param_shape, "p1");
  ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
                                   /*dynamic_size_param_index=*/{1},
                                   /*target_param_num=*/0,
                                   /*target_param_index=*/{0},
                                   /*target_dim_num=*/0));
  auto gte0 = GetTupleElement(p0, 0);
  auto gte1 = GetTupleElement(p1, 0);
  std::vector<XlaOp> input_operands = {gte0, gte1};
  XlaBuilder bsum(TestName());
  auto p2 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x0");
  auto p3 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "x1");
  auto p4 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "y0");
  auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1");
  std::vector<XlaOp> output_operands = {Add(p2, p4), Add(p3, p5)};
  Tuple(&bsum, absl::MakeSpan(output_operands));
  TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
  auto init = ConstantR0<float>(&b, 0.f);
  ReduceWindow(input_operands, {init, init}, sum,
               /*window_dimensions=*/{1, 2, 4},
               /*window_strides=*/{1, 1, 1}, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  VLOG(2) << module->entry_computation()->root_instruction()->ToString()
          << "\n";
  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(0).dynamic_dimensions(),
                              {true, false, false}))
      << result_shape.tuple_shapes(0);
  EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(1).dynamic_dimensions(),
                              {true, false, false}))
      << result_shape.tuple_shapes(1);
}

TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
  XlaBuilder b(TestName());
  Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
@@ -241,7 +241,10 @@ tf_cc_test(
        "//tensorflow/compiler/xla:test",
        "//tensorflow/compiler/xla:test_helpers",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:window_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:padding",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # fixdeps: keep
        "//tensorflow/core:lib",
        "@com_google_absl//absl/strings",
@@ -2540,6 +2540,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
  //   if I in bounds of input
  //     value = function(value, input[I])
  // output[O] = value
  if (reduce_window->shape().IsTuple()) {
    return Status(tensorflow::error::UNIMPLEMENTED,
                  "Variadic reduce window op is not yet fully supported.");
  }
  const HloInstruction* operand = reduce_window->operand(0);
  const Window& window = reduce_window->window();

@@ -1932,6 +1932,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
  }

  Status HandleReduceWindow(HloInstruction* reduce_window) override {
    if (reduce_window->shape().IsTuple()) {
      return Status(tensorflow::error::UNIMPLEMENTED,
                    "Variadic reduce window op is not yet fully supported.");
    }
    auto operand = reduce_window->operand(0);
    const Window& window = reduce_window->window();
    HloComputation* function = reduce_window->to_apply();
@@ -515,11 +515,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
      break;
    }
    case HloOpcode::kReduceWindow:
      TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
          << "Reduce window should have an even number of operands but "
             "sees "
          << proto.operand_ids_size();
      TF_RET_CHECK(proto.called_computation_ids_size() == 1)
          << "ReduceWindow should have 1 called computation but sees "
          << proto.called_computation_ids_size();
      instruction = CreateReduceWindow(shape, operands(0), operands(1),
                                       proto.window(), computations(0));
      {
        const auto reduce_operands = all_operands();
        auto inputs = absl::MakeSpan(reduce_operands)
                          .subspan(0, reduce_operands.size() / 2);
        auto init_values =
            absl::MakeSpan(reduce_operands)
                .subspan(reduce_operands.size() / 2, reduce_operands.size());
        instruction = CreateReduceWindow(shape, inputs, init_values,
                                         proto.window(), computations(0));
      }
      break;
    case HloOpcode::kSelectAndScatter:
      TF_RET_CHECK(proto.called_computation_ids_size() == 2)
@@ -1273,6 +1285,13 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
      shape, operand, init_value, window, reduce_computation);
}

/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values, const Window& window,
    HloComputation* reduce_computation) {
  return absl::make_unique<HloReduceWindowInstruction>(
      shape, operands, init_values, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
                                        HloInstruction* operand,
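A hedged aside on the proto decoding above: for a variadic reduce-window with N inputs, the flat operand list is laid out as [input_0 ... input_{N-1}, init_0 ... init_{N-1}], and CreateFromProto recovers the two halves with subspan. The standalone sketch below (invented helper, with int* standing in for HloInstruction*) shows the same split, including the detail that passing the full container size as the second subspan argument is safe because absl::Span::subspan clamps the length to the end of the span.

#include <cstddef>
#include <utility>
#include <vector>

#include "absl/types/span.h"

std::pair<absl::Span<int* const>, absl::Span<int* const>>
SplitReduceWindowOperands(const std::vector<int*>& all_operands) {
  const size_t n = all_operands.size() / 2;  // operand count must be even.
  auto all = absl::MakeSpan(all_operands);
  // subspan(pos, len) clamps len to the remaining size, so using the full
  // size for the tail, as the decoding above does, takes everything from n on.
  return {all.subspan(0, n), all.subspan(n, all_operands.size())};
}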
@@ -830,6 +830,16 @@ class HloInstruction {
      const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
      const Window& window, HloComputation* reduce_computation);

  // A more general, multiple-argument version of the above.
  // The reduce_computation being applied now takes N arguments:
  // [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
  // valueN], and returns an N-tuple. The operands and init_values now each
  // contain a span of N input arrays and N initial values.
  static std::unique_ptr<HloInstruction> CreateReduceWindow(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);

  // Creates a batch-norm-training instruction.
  static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
      const Shape& shape, HloInstruction* operand, HloInstruction* scale,
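To make the calling convention documented above concrete, here is a hedged sketch, not from the commit, of an N = 2 reducer built at the HLO level: parameters are ordered [accumulator0, accumulator1, value0, value1] and the root is a 2-tuple, so the computation could be passed to the variadic CreateReduceWindow declared above. The helper name is invented.

#include <memory>

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"

std::unique_ptr<xla::HloComputation> MakeTwoInputSumReducer() {
  using xla::HloInstruction;
  const xla::Shape scalar = xla::ShapeUtil::MakeShape(xla::F32, {});
  xla::HloComputation::Builder builder("sum2");  // Name is arbitrary.
  auto* acc0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar, "acc0"));
  auto* acc1 = builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar, "acc1"));
  auto* val0 = builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar, "val0"));
  auto* val1 = builder.AddInstruction(
      HloInstruction::CreateParameter(3, scalar, "val1"));
  auto* sum0 = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, xla::HloOpcode::kAdd, acc0, val0));
  auto* sum1 = builder.AddInstruction(
      HloInstruction::CreateBinary(scalar, xla::HloOpcode::kAdd, acc1, val1));
  // Root: a 2-tuple, one element per reduced input.
  builder.AddInstruction(HloInstruction::CreateTuple({sum0, sum1}));
  return builder.Build();
}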
@@ -2237,9 +2237,21 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
HloReduceWindowInstruction::HloReduceWindowInstruction(
    const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
    const Window& window, HloComputation* reduce_computation)
    : HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
                                 absl::MakeSpan(&init_value, 1), window,
                                 reduce_computation) {}

HloReduceWindowInstruction::HloReduceWindowInstruction(
    const Shape& shape, absl::Span<HloInstruction* const> operands,
    absl::Span<HloInstruction* const> init_values, const Window& window,
    HloComputation* reduce_computation)
    : HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
  AppendOperand(operand);
  AppendOperand(init_value);
  for (auto* operand : operands) {
    AppendOperand(operand);
  }
  for (auto* init_value : init_values) {
    AppendOperand(init_value);
  }
  AppendComputation(reduce_computation);
}

@@ -20,6 +20,7 @@ limitations under the License.

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
@@ -1294,10 +1295,43 @@ class HloReduceWindowInstruction : public HloInstruction {
                                      HloInstruction* init_value,
                                      const Window& window,
                                      HloComputation* reduce_computation);
  explicit HloReduceWindowInstruction(
      const Shape& shape, absl::Span<HloInstruction* const> operands,
      absl::Span<HloInstruction* const> init_values, const Window& window,
      HloComputation* reduce_computation);
  const Window& window() const override { return window_; }
  void set_window(const Window& window) override { window_ = window; }
  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;
  // Returns the number of input arrays (and, consequently, the number of
  // init values) this reduce has.
  int64 input_count() const { return operand_count() / 2; }
  // Returns the input tensors to be reduced.
  absl::Span<HloInstruction* const> input_arrays() const {
    return absl::MakeSpan(operands()).subspan(0, input_count());
  }
  // Returns the init values of the reduction.
  absl::Span<HloInstruction* const> init_values() const {
    return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
  }
  // Returns the shapes of input tensors to be reduced.
  absl::InlinedVector<const Shape*, 2> input_array_shapes() const {
    absl::InlinedVector<const Shape*, 2> shapes;
    for (const auto* op : input_arrays()) {
      VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
      shapes.push_back(&op->shape());
      VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
    }
    return shapes;
  }
  // Returns the shapes of the init values of the reduction.
  absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
    absl::InlinedVector<const Shape*, 2> shapes;
    for (const auto* op : init_values()) {
      shapes.push_back(&op->shape());
    }
    return shapes;
  }

 private:
  std::vector<string> ExtraAttributesToStringImpl(
@@ -1310,6 +1344,7 @@ class HloReduceWindowInstruction : public HloInstruction {
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  Window window_;
};

@@ -119,7 +119,7 @@ namespace xla {
  V(kRecvDone, "recv-done", 1) \
  V(kReduce, "reduce", kHloOpcodeIsVariadic) \
  V(kReducePrecision, "reduce-precision", 1) \
  V(kReduceWindow, "reduce-window", 2) \
  V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic) \
  V(kRemainder, "remainder", 2) \
  V(kReplicaId, "replica-id", 0) \
  V(kReshape, "reshape", 1) \
@@ -65,6 +65,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
      case HloOpcode::kRng:
      case HloOpcode::kSort:
      case HloOpcode::kTuple:
      case HloOpcode::kReduceWindow:
        EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
        break;
      default:
@@ -2084,7 +2084,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
        arg_shapes.size());
  }
  int64 num_reduced_args = arg_shapes.size() / 2;

  auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
  // Check that all of the reduced tensors have the same dimensions. The element
  // types may be different.
@@ -2097,7 +2096,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
          ShapeUtil::HumanString(*reduced_args[i]));
    }
  }

  // Check that the dimensions to reduce are in-bounds for the given shape.
  // We've already verified all reduced tensors have the same dimensions, so it
  // doesn't matter which one we choose.
@@ -2156,6 +2154,43 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
  return InferReduceWindowShape(operand_shape, init_value_shape, window);
}

/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
    const Window& window, const ProgramShape& to_apply_shape) {
  auto number_of_input = operands.size();
  // Check that all of the reduced tensors have the same dimensions. The element
  // types may be different.
  for (int64 i = 1; i < number_of_input; ++i) {
    if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
      return InvalidArgument(
          "All reduced tensors must have the same dimension. Tensor 0 has "
          "shape %s, Tensor %d has shape %s",
          ShapeUtil::HumanString(*operands[0]), i,
          ShapeUtil::HumanString(*operands[i]));
    }
  }
  std::vector<PrimitiveType> operand_element_type_vec;
  for (const Shape* s : operands) {
    operand_element_type_vec.push_back(s->element_type());
  }
  TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
                                        operand_element_type_vec,
                                        /*inputs=*/number_of_input));
  std::vector<Shape> output_shape_vec;
  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(
        auto cur_output_shape,
        InferReduceWindowShape(*operands[i], *init_values[i], window));
    output_shape_vec.push_back(cur_output_shape);
  }
  if (ShapeUtil::IsScalar(to_apply_shape.result())) {
    CHECK_EQ(output_shape_vec.size(), 1);
    return output_shape_vec[0];
  } else {
    return ShapeUtil::MakeTupleShape(output_shape_vec);
  }
}

/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
    const Shape& operand_shape, const Shape& init_value_shape,
    const Window& window) {
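As a worked example of what the per-input inference above produces, an assumption-level sketch rather than commit code: with valid padding and unit dilation, each output dimension is floor((input - window) / stride) + 1, clamped at zero when the window does not fit, which is where the f32[5, 2, 0] / s32[5, 2, 0] tuple expected by the ReduceWindowMultiOutput test further down comes from.

#include <cstdint>

// Invented helper mirroring the valid-padding, unit-dilation case only.
int64_t ValidOutputDim(int64_t input, int64_t window, int64_t stride) {
  if (input < window) return 0;  // Window does not fit: empty dimension.
  return (input - window) / stride + 1;
}
// For input {5, 3, 1}, window {1, 2, 4}, strides {1, 1, 1}:
//   ValidOutputDim(5, 1, 1) == 5
//   ValidOutputDim(3, 2, 1) == 2
//   ValidOutputDim(1, 4, 1) == 0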
@@ -164,10 +164,16 @@ class ShapeInference {
  static StatusOr<Shape> InferReduceWindowShape(
      const Shape& operand_shape, const Shape& init_value, const Window& window,
      const ProgramShape& to_apply_shape);

  static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
                                                const Shape& init_value,
                                                const Window& window);
  static StatusOr<Shape> InferReduceWindowShape(
      absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
      const Window& window, const ProgramShape& to_apply_shape);

  static StatusOr<Shape> InferReduceWindowShape(
      absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
      const Window& window);

  // Infers the shape produced by scattering the given source shape to the
  // selected indices of each window on the operand shape.
@@ -19,6 +19,7 @@ limitations under the License.

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@@ -912,6 +913,32 @@ TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
      inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
@@ -948,6 +975,29 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
              HasSubstr("must have at least 2 arguments, has 0"));
}

TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_FALSE(inferred_status.status().ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});