[XLA] Add xla builder support for variadic reduce window op. This is the first CL leading to full support of variaduc reduce window.

PiperOrigin-RevId: 338417352
Change-Id: I8e5907f0ddf2a29081c4d84d593b30f5c3eda6ed
This commit is contained in:
A. Unique TensorFlower 2020-10-22 00:13:09 -07:00 committed by TensorFlower Gardener
parent abfdb66c8e
commit 96bd0a0b68
15 changed files with 322 additions and 24 deletions

View File

@ -2324,31 +2324,53 @@ XlaOp XlaBuilder::ReduceWindow(XlaOp operand, XlaOp init_value,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_RETURN_IF_ERROR(
ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
window_dimensions, window_strides));
return ReduceWindow(absl::MakeSpan(&operand, 1),
absl::MakeSpan(&init_value, 1), computation,
window_dimensions, window_strides, padding);
}
XlaOp XlaBuilder::ReduceWindow(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
const Shape* operand_shape = nullptr;
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(operand_shape, GetShapePtr(operand));
TF_RETURN_IF_ERROR(
ValidatePaddingValues(AsInt64Slice(operand_shape->dimensions()),
window_dimensions, window_strides));
}
CHECK(operand_shape != nullptr);
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(operand_shape->dimensions()),
window_dimensions, window_strides, padding);
return ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
operands, init_values, computation, window_dimensions, window_strides,
/*base_dilations=*/{}, /*window_dilations=*/{}, padding_values);
});
}
XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
XlaOp operand, XlaOp init_value, const XlaComputation& computation,
absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
std::vector<const Shape*> operand_shapes, init_shapes;
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
for (int i = 0; i < operands.size(); ++i) {
const auto& operand = operands[i];
const auto& init_value = init_values[i];
TF_ASSIGN_OR_RETURN(const Shape* operand_shape, GetShapePtr(operand));
operand_shapes.push_back(operand_shape);
TF_ASSIGN_OR_RETURN(const Shape* init_shape, GetShapePtr(init_value));
init_shapes.push_back(init_shape);
}
TF_ASSIGN_OR_RETURN(const ProgramShape& to_apply_shape,
computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(auto window,
@ -2358,12 +2380,33 @@ XlaOp XlaBuilder::ReduceWindowWithGeneralPadding(
/*rhs_dilation=*/window_dilations));
TF_ASSIGN_OR_RETURN(
Shape shape, ShapeInference::InferReduceWindowShape(
*operand_shape, *init_shape, window, to_apply_shape));
return ReduceWindowInternal(shape, operand, init_value, computation,
absl::MakeSpan(operand_shapes),
absl::MakeSpan(init_shapes), window, to_apply_shape));
return ReduceWindowInternal(shape, operands, init_values, computation,
std::move(window));
});
}
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
const Shape& shape, absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values, const XlaComputation& computation,
Window window) {
if (operands.size() == 1) {
return ReduceWindowInternal(shape, operands[0], init_values[0], computation,
window);
} else {
HloInstructionProto instr;
*instr.mutable_shape() = shape.ToProto();
*instr.mutable_window() = std::move(window);
AddCalledComputation(computation, &instr);
std::vector<XlaOp> args;
args.insert(args.end(), operands.begin(), operands.end());
args.insert(args.end(), init_values.begin(), init_values.end());
return AddInstruction(std::move(instr), HloOpcode::kReduceWindow,
absl::MakeSpan(args));
}
}
StatusOr<XlaOp> XlaBuilder::ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window) {
@ -4067,6 +4110,17 @@ XlaOp ReduceWindow(const XlaOp operand, const XlaOp init_value,
padding);
}
XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding) {
CHECK(!operands.empty());
return operands[0].builder()->ReduceWindow(operands, init_values, computation,
window_dimensions, window_strides,
padding);
}
XlaOp ReduceWindowWithGeneralPadding(
const XlaOp operand, const XlaOp init_value,
const XlaComputation& computation,
@ -4076,8 +4130,9 @@ XlaOp ReduceWindowWithGeneralPadding(
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding) {
return operand.builder()->ReduceWindowWithGeneralPadding(
operand, init_value, computation, window_dimensions, window_strides,
base_dilations, window_dilations, padding);
absl::MakeSpan(&operand, 1), absl::MakeSpan(&init_value, 1), computation,
window_dimensions, window_strides, base_dilations, window_dilations,
padding);
}
XlaOp AllGather(const XlaOp operand, int64 all_gather_dimension,

View File

@ -648,18 +648,28 @@ class XlaBuilder {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);
XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);
XlaOp ReduceWindowWithGeneralPadding(
XlaOp operand, XlaOp init_value, const XlaComputation& computation,
absl::Span<const XlaOp> operands, absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
absl::Span<const int64> base_dilations,
absl::Span<const int64> window_dilations,
absl::Span<const std::pair<int64, int64>> padding);
StatusOr<XlaOp> ReduceWindowInternal(const Shape& shape,
absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
Window window);
virtual StatusOr<XlaOp> ReduceWindowInternal(
const Shape& shape, XlaOp operand, XlaOp init_value,
const XlaComputation& computation, Window window);
XlaOp CrossReplicaSum(XlaOp operand,
absl::Span<const ReplicaGroup> replica_groups = {});
@ -1137,6 +1147,12 @@ class XlaBuilder {
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding);
friend XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides,
Padding padding);
friend XlaOp ReduceWindowWithGeneralPadding(
XlaOp operand, XlaOp init_value, const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
@ -1965,6 +1981,12 @@ XlaOp ReduceWindow(XlaOp operand, XlaOp init_value,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);
XlaOp ReduceWindow(absl::Span<const XlaOp> operands,
absl::Span<const XlaOp> init_values,
const XlaComputation& computation,
absl::Span<const int64> window_dimensions,
absl::Span<const int64> window_strides, Padding padding);
// As ReduceWindow(), but the padding is given in the format
// returned by MakePadding().
XlaOp ReduceWindowWithGeneralPadding(

View File

@ -873,6 +873,8 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) {
ReduceWindow(gte, init, sum, /*window_dimensions=*/{1, 2, 4},
/*window_strides=*/{1, 1, 1}, Padding::kValid);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
VLOG(2) << module->entry_computation()->root_instruction()->ToString()
<< "\n";
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
EXPECT_TRUE(
@ -880,6 +882,46 @@ TEST_F(XlaBuilderTest, DynamicReduceWindow) {
<< result_shape;
}
TEST_F(XlaBuilderTest, VariadicDynamicReduceWindow) {
XlaBuilder b(TestName());
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
{ShapeUtil::MakeShape(F32, {2, 4, 8}, {true, false, false}),
ShapeUtil::MakeShape(U32, {})});
auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
auto p1 = Parameter(&b, 1, tuple_param_shape, "p1");
ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
/*dynamic_size_param_index=*/{1},
/*target_param_num=*/0,
/*target_param_index=*/{0},
/*target_dim_num=*/0));
auto gte0 = GetTupleElement(p0, 0);
auto gte1 = GetTupleElement(p1, 0);
std::vector<XlaOp> input_operands = {gte0, gte1};
XlaBuilder bsum(TestName());
auto p2 = Parameter(&bsum, 0, ShapeUtil::MakeShape(F32, {}), "x0");
auto p3 = Parameter(&bsum, 1, ShapeUtil::MakeShape(F32, {}), "x1");
auto p4 = Parameter(&bsum, 2, ShapeUtil::MakeShape(F32, {}), "y0");
auto p5 = Parameter(&bsum, 3, ShapeUtil::MakeShape(F32, {}), "y1");
std::vector<XlaOp> output_operands = {Add(p2, p4), Add(p3, p5)};
Tuple(&bsum, absl::MakeSpan(output_operands));
TF_ASSERT_OK_AND_ASSIGN(auto sum, bsum.Build());
auto init = ConstantR0<float>(&b, 0.f);
ReduceWindow(input_operands, {init, init}, sum,
/*window_dimensions=*/{1, 2, 4},
/*window_strides=*/{1, 1, 1}, Padding::kValid);
TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
VLOG(2) << module->entry_computation()->root_instruction()->ToString()
<< "\n";
const Shape& result_shape =
module->entry_computation()->root_instruction()->shape();
EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(0).dynamic_dimensions(),
{true, false, false}))
<< result_shape.tuple_shapes(0);
EXPECT_TRUE(ContainersEqual(result_shape.tuple_shapes(1).dynamic_dimensions(),
{true, false, false}))
<< result_shape.tuple_shapes(1);
}
TEST_F(XlaBuilderTest, DynamicSelectAndScatter) {
XlaBuilder b(TestName());
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(

View File

@ -241,7 +241,10 @@ tf_cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:window_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",

View File

@ -2540,6 +2540,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalReduceWindow(
// if I in bounds of input
// value = function(value, input[I])
// output[O] = value
if (reduce_window->shape().IsTuple()) {
return Status(tensorflow::error::UNIMPLEMENTED,
"Variadic reduce window op is not yet fully supported.");
}
const HloInstruction* operand = reduce_window->operand(0);
const Window& window = reduce_window->window();

View File

@ -1932,6 +1932,10 @@ class HloEvaluatorTypedVisitor : public DfsHloVisitorWithDefault {
}
Status HandleReduceWindow(HloInstruction* reduce_window) override {
if (reduce_window->shape().IsTuple()) {
return Status(tensorflow::error::UNIMPLEMENTED,
"Variadic reduce window op is not yet fully supported.");
}
auto operand = reduce_window->operand(0);
const Window& window = reduce_window->window();
HloComputation* function = reduce_window->to_apply();

View File

@ -515,11 +515,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
}
case HloOpcode::kReduceWindow:
TF_RET_CHECK(proto.operand_ids_size() % 2 == 0)
<< "Reduce window should have an even number of operands but "
"sees "
<< proto.operand_ids_size();
TF_RET_CHECK(proto.called_computation_ids_size() == 1)
<< "ReduceWindow should have 1 called computation but sees "
<< proto.called_computation_ids_size();
instruction = CreateReduceWindow(shape, operands(0), operands(1),
proto.window(), computations(0));
{
const auto reduce_operands = all_operands();
auto inputs = absl::MakeSpan(reduce_operands)
.subspan(0, reduce_operands.size() / 2);
auto init_values =
absl::MakeSpan(reduce_operands)
.subspan(reduce_operands.size() / 2, reduce_operands.size());
instruction = CreateReduceWindow(shape, inputs, init_values,
proto.window(), computations(0));
}
break;
case HloOpcode::kSelectAndScatter:
TF_RET_CHECK(proto.called_computation_ids_size() == 2)
@ -1273,6 +1285,13 @@ HloInstruction::CreateBitcastConvert(const Shape& shape,
shape, operand, init_value, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReduceWindow(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> init_values, const Window& window,
HloComputation* reduce_computation) {
return absl::make_unique<HloReduceWindowInstruction>(
shape, operands, init_values, window, reduce_computation);
}
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateBatchNormTraining(const Shape& shape,
HloInstruction* operand,

View File

@ -830,6 +830,16 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation);
// A more general, multiple-argument version of the above.
// The reduce_computation being applied,now takes N arguments:
// [accumulator0, accumulator1, ..., accumulatorN, value0, value1, ...,
// valueN], and returns an N-tuple. The operands and init_values now each
// contain a span of N input arrays and n initial values.
static std::unique_ptr<HloInstruction> CreateReduceWindow(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> init_values, const Window& window,
HloComputation* reduce_computation);
// Creates a batch-norm-training instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,

View File

@ -2237,9 +2237,21 @@ HloConvolutionInstruction::CloneWithNewOperandsImpl(
HloReduceWindowInstruction::HloReduceWindowInstruction(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation)
: HloReduceWindowInstruction(shape, absl::MakeSpan(&operand, 1),
absl::MakeSpan(&init_value, 1), window,
reduce_computation) {}
HloReduceWindowInstruction::HloReduceWindowInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> init_values, const Window& window,
HloComputation* reduce_computation)
: HloInstruction(HloOpcode::kReduceWindow, shape), window_(window) {
AppendOperand(operand);
AppendOperand(init_value);
for (auto* operand : operands) {
AppendOperand(operand);
}
for (auto* init_value : init_values) {
AppendOperand(init_value);
}
AppendComputation(reduce_computation);
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
@ -1294,10 +1295,43 @@ class HloReduceWindowInstruction : public HloInstruction {
HloInstruction* init_value,
const Window& window,
HloComputation* reduce_computation);
explicit HloReduceWindowInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::Span<HloInstruction* const> init_values, const Window& window,
HloComputation* reduce_computation);
const Window& window() const override { return window_; }
void set_window(const Window& window) override { window_ = window; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
// Returns the number of input arrays (and, consequentially, the number of
// init values) this reduce has.
int64 input_count() const { return operand_count() / 2; }
// Returns the input tensors to be reduced.
absl::Span<HloInstruction* const> input_arrays() const {
return absl::MakeSpan(operands()).subspan(0, input_count());
}
// Returns the init values of the reduction.
absl::Span<HloInstruction* const> init_values() const {
return absl::MakeSpan(operands()).subspan(input_count(), operand_count());
}
// Returns the shapes of input tensors to be reduced.
absl::InlinedVector<const Shape*, 2> input_array_shapes() const {
absl::InlinedVector<const Shape*, 2> shapes;
for (const auto* op : input_arrays()) {
VLOG(2) << "Pushing input array shape for: " << op->ToString() << "\n";
shapes.push_back(&op->shape());
VLOG(2) << "Pushed shape: " << shapes.back()->ToString() << "\n";
}
return shapes;
}
// Returns the init values of the reduction.
absl::InlinedVector<const Shape*, 2> init_value_shapes() const {
absl::InlinedVector<const Shape*, 2> shapes;
for (const auto* op : init_values()) {
shapes.push_back(&op->shape());
}
return shapes;
}
private:
std::vector<string> ExtraAttributesToStringImpl(
@ -1310,6 +1344,7 @@ class HloReduceWindowInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
Window window_;
};

View File

@ -119,7 +119,7 @@ namespace xla {
V(kRecvDone, "recv-done", 1) \
V(kReduce, "reduce", kHloOpcodeIsVariadic) \
V(kReducePrecision, "reduce-precision", 1) \
V(kReduceWindow, "reduce-window", 2) \
V(kReduceWindow, "reduce-window", kHloOpcodeIsVariadic) \
V(kRemainder, "remainder", 2) \
V(kReplicaId, "replica-id", 0) \
V(kReshape, "reshape", 1) \

View File

@ -65,6 +65,7 @@ TEST(HloOpcodeTest, OpcodeProperties) {
case HloOpcode::kRng:
case HloOpcode::kSort:
case HloOpcode::kTuple:
case HloOpcode::kReduceWindow:
EXPECT_TRUE(HloOpcodeIsVariadic(opcode));
break;
default:

View File

@ -2084,7 +2084,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
arg_shapes.size());
}
int64 num_reduced_args = arg_shapes.size() / 2;
auto reduced_args = arg_shapes.subspan(0, num_reduced_args);
// Check that all of the reduced tensors have the same dimensions. The element
// types may be different.
@ -2097,7 +2096,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
ShapeUtil::HumanString(*reduced_args[i]));
}
}
// Check that the dimensions to reduce are in-bounds for the given shape.
// We've already verified all reduced tensors have the same dimensions, so it
// doesn't matter which one we choose.
@ -2156,6 +2154,43 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return InferReduceWindowShape(operand_shape, init_value_shape, window);
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
const Window& window, const ProgramShape& to_apply_shape) {
auto number_of_input = operands.size();
// Check that all of the reduced tensors have the same dimensions. The element
// types may be different.
for (int64 i = 1; i < number_of_input; ++i) {
if (!ShapeUtil::SameDimensions(*operands[0], *operands[i])) {
return InvalidArgument(
"All reduced tensors must have the same dimension. Tensor 0 has "
"shape %s, Tensor %d has shape %s",
ShapeUtil::HumanString(*operands[0]), i,
ShapeUtil::HumanString(*operands[i]));
}
}
std::vector<PrimitiveType> operand_element_type_vec;
for (const Shape* s : operands) {
operand_element_type_vec.push_back(s->element_type());
}
TF_RETURN_IF_ERROR(VerifyReducerShape(to_apply_shape, init_values,
operand_element_type_vec,
/*inputs=*/number_of_input));
std::vector<Shape> output_shape_vec;
for (int i = 0; i < operands.size(); ++i) {
TF_ASSIGN_OR_RETURN(
auto cur_output_shape,
InferReduceWindowShape(*operands[i], *init_values[i], window));
output_shape_vec.push_back(cur_output_shape);
}
if (ShapeUtil::IsScalar(to_apply_shape.result())) {
CHECK_EQ(output_shape_vec.size(), 1);
return output_shape_vec[0];
} else {
return ShapeUtil::MakeTupleShape(output_shape_vec);
}
}
/* static */ StatusOr<Shape> ShapeInference::InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value_shape,
const Window& window) {

View File

@ -164,10 +164,16 @@ class ShapeInference {
static StatusOr<Shape> InferReduceWindowShape(
const Shape& operand_shape, const Shape& init_value, const Window& window,
const ProgramShape& to_apply_shape);
static StatusOr<Shape> InferReduceWindowShape(const Shape& operand_shape,
const Shape& init_value,
const Window& window);
static StatusOr<Shape> InferReduceWindowShape(
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
const Window& window, const ProgramShape& to_apply_shape);
static StatusOr<Shape> InferReduceWindowShape(
absl::Span<const Shape*> operands, absl::Span<const Shape*> init_values,
const Window& window);
// Infers the shape produced by scattering the given source shape to the
// selected indices of each window on the operand shape.

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
@ -912,6 +913,32 @@ TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
inferred_status.ValueOrDie()));
}
TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
std::vector<const Shape*> inits = {&f32_, &s32_};
ProgramShape to_apply = ShapeUtil::MakeProgramShape(
{f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
std::vector<int64> window_dimensions = {1, 2, 4};
std::vector<int64> window_strides = {1, 1, 1};
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
window_strides, Padding::kValid);
TF_ASSERT_OK_AND_ASSIGN(
Window window,
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding_values, {}, {}));
auto inferred_status = ShapeInference::InferReduceWindowShape(
absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
EXPECT_IS_OK(inferred_status.status());
EXPECT_TRUE(ShapeUtil::Equal(
ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
ShapeUtil::MakeShape(S32, {5, 2, 0})}),
inferred_status.ValueOrDie()));
}
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
@ -948,6 +975,29 @@ TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
HasSubstr("must have at least 2 arguments, has 0"));
}
TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
std::vector<const Shape*> inits = {&f32_, &s32_};
ProgramShape to_apply = ShapeUtil::MakeProgramShape(
{f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
std::vector<int64> window_dimensions = {1, 2, 4};
std::vector<int64> window_strides = {1, 1, 1};
std::vector<std::pair<int64, int64>> padding_values =
MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
window_strides, Padding::kValid);
TF_ASSERT_OK_AND_ASSIGN(
Window window,
ShapeInference::InferWindowFromDimensions(
window_dimensions, window_strides, padding_values, {}, {}));
auto inferred_status = ShapeInference::InferReduceWindowShape(
absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
EXPECT_FALSE(inferred_status.status().ok());
EXPECT_THAT(inferred_status.status().error_message(),
HasSubstr("f32[] vs s32[]"));
}
TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});