/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shape_inference.h"

#include <string>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

using ::testing::ContainsRegex;
using ::testing::HasSubstr;

class ShapeInferenceTest : public ::testing::Test {
 protected:
  // Some handy scalar shapes.
  const Shape s32_ = ShapeUtil::MakeShape(S32, {});
  const Shape f16_ = ShapeUtil::MakeShape(F16, {});
  const Shape f32_ = ShapeUtil::MakeShape(F32, {});
  const Shape f64_ = ShapeUtil::MakeShape(F64, {});
  const Shape pred_ = ShapeUtil::MakeShape(PRED, {});

  // Some handy vector and matrix shapes of F32 type.
  // Suffix: vector_length_, matrix_rows_cols_
  const Shape vector_32_ = ShapeUtil::MakeShape(F32, {32});
  const Shape vector_64_ = ShapeUtil::MakeShape(F32, {64});
  const Shape matrix_32_48_ = ShapeUtil::MakeShape(F32, {32, 48});
  const Shape matrix_32_64_ = ShapeUtil::MakeShape(F32, {32, 64});
  const Shape matrix_64_48_ = ShapeUtil::MakeShape(F32, {64, 48});

  // Some handy S32 arrays.
  const Shape s32matrix_64_64_ = ShapeUtil::MakeShape(S32, {64, 64});
};

// Subclass for testing InferReduceShape.
class ReduceShapeInferenceTest : public ShapeInferenceTest {
 protected:
  // Helper that runs reduce shape inference with the input 'arg' and given
  // dimensions to reduce, and checks the inferred shape is as expected. The
  // element type here is hard-coded to F32.
  void ExpectInferredReduceShape(const Shape& expected_inferred_shape,
                                 const Shape& arg,
                                 absl::Span<const int64> dimensions_to_reduce) {
    ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
    auto inferred_status = ShapeInference::InferReduceShape(
        {&arg, &f32_}, dimensions_to_reduce, to_apply);
    EXPECT_IS_OK(inferred_status.status());
    EXPECT_TRUE(ShapeUtil::Equal(expected_inferred_shape,
                                 inferred_status.ValueOrDie()));
  }
};

// Subclass for testing InferSelectAndScatterShape.
class SelectAndScatterShapeInferenceTest : public ShapeInferenceTest {
 protected:
  SelectAndScatterShapeInferenceTest() {
    operand_shape_ = ShapeUtil::MakeShape(F32, {8, 16});
    source_shape_ = ShapeUtil::MakeShape(F32, {4, 8});
    WindowDimension dim;
    dim.set_size(2);
    dim.set_stride(2);
    dim.set_padding_low(0);
    dim.set_padding_high(0);
    dim.set_window_dilation(1);
    dim.set_base_dilation(1);
    *window_.add_dimensions() = dim;
    *window_.add_dimensions() = dim;
    init_value_shape_ = ShapeUtil::MakeShape(F32, {});
    select_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
    scatter_program_shape_ = ShapeUtil::MakeProgramShape(
        {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  }

  Shape operand_shape_;
  Shape source_shape_;
  Window window_;
  Shape init_value_shape_;
  ProgramShape select_program_shape_;
  ProgramShape scatter_program_shape_;
};

TEST_F(ShapeInferenceTest, UnaryNegateMatrix) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferUnaryOpShape(HloOpcode::kNegate, matrix_shape);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_shape, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, SelectScalarPredBetweenTuples) {
  Shape tuple = ShapeUtil::MakeTupleShape({s32_, f32_});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, tuple, tuple);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Expected array argument for select"));
}

TEST_F(ShapeInferenceTest, SelectScalarPredBetweenArrays) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));
}

TEST_F(ShapeInferenceTest, SelectArrayPredBetweenArrays) {
  auto predarray = ShapeUtil::MakeShape(PRED, {64, 48});
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, predarray, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, SelectBadShapes) {
  auto inferred_status_error1 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, pred_, matrix_64_48_, matrix_32_64_);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Operands to select must be the same shape"));

  auto inferred_status_error2 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, s32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("pred operand must have PRED"));

  auto inferred_status_error3 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeShape(PRED, {64}), matrix_64_48_,
      matrix_64_48_);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(
      inferred_status_error3.status().error_message(),
      HasSubstr("Operands to select and predicate must be the same shape"));

  // Tuples have a TUPLE element type and cannot be the pred of a select.
  auto inferred_status_error4 = ShapeInference::InferTernaryOpShape(
      HloOpcode::kSelect, ShapeUtil::MakeTupleShape({pred_, pred_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}),
      ShapeUtil::MakeTupleShape({f32_, f32_}));
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("Expected array argument for select pred"));
}

TEST_F(ShapeInferenceTest, ClampAllMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, matrix_64_48_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(matrix_64_48_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, ClampAllScalar) {
  auto inferred_status =
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, f32_);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, ClampMinScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMaxScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampOperandScalar) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMinMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, matrix_64_48_, f32_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampMaxMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, f32_, matrix_64_48_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampOperandMatrix) {
  auto inferred_status = ShapeInference::InferTernaryOpShape(
      HloOpcode::kClamp, f32_, matrix_64_48_, f32_);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Clamp with different shapes"));
}

TEST_F(ShapeInferenceTest, ClampBadShapes) {
  // Type mismatch
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, s32_, f32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, s32_, f32_)
          .ok());
  ASSERT_FALSE(
      ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_, f32_, s32_)
          .ok());
  // Dimension mismatch
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_64_, vector_32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_64_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(
                   HloOpcode::kClamp, vector_32_, vector_32_, vector_64_)
                   .ok());
  // Dimension mismatch, where one operand is a scalar
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, vector_32_, f32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp,
                                                   vector_64_, f32_, vector_32_)
                   .ok());
  ASSERT_FALSE(ShapeInference::InferTernaryOpShape(HloOpcode::kClamp, f32_,
                                                   vector_64_, vector_32_)
                   .ok());
}

TEST_F(ShapeInferenceTest, Complex) {
  auto complex_shape = [&](const Shape& lhs, const Shape& rhs,
                           absl::Span<const int64> bcast) {
    return ShapeInference::InferBinaryOpShape(HloOpcode::kComplex, lhs, rhs,
                                              bcast);
  };
  // Inputs must be FP.
  ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok());
  ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok());
  // Component types must match.
  ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok());
  // Only F32->C64 and F64->C128 supported.
  ASSERT_FALSE(complex_shape(f16_, f16_, {}).ok());
  // Validate correct uses.
  Shape c64_32 = ShapeUtil::MakeShape(C64, {32});
  TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {})));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32));

  Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64});
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(vector_64_, matrix_32_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, vector_64_, {1}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result,
                          complex_shape(matrix_32_64_, matrix_32_64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));
  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64));

  TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f64_, f64_, {}));
  ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C128, {})));
}

TEST_F(ShapeInferenceTest, VariadicOpTuplify) {
  StatusOr<Shape> result =
      ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, {&s32_, &f32_});
  ASSERT_IS_OK(result.status());
  ASSERT_TRUE(ShapeUtil::Equal(result.ValueOrDie(),
                               ShapeUtil::MakeTupleShape({s32_, f32_})));
}
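
// An 8x8 input reduced with a 2x2 window at stride 2 and no padding halves
// each dimension: floor((8 - 2) / 2) + 1 = 4, so the result is 4x4.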
TEST_F(ShapeInferenceTest, ReduceWindowInHalf) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {8, 8});
  Window window;
  WindowDimension dim;
  dim.set_size(2);
  dim.set_stride(2);
  dim.set_padding_low(0);
  dim.set_padding_high(0);
  dim.set_window_dilation(1);
  dim.set_base_dilation(1);
  *window.add_dimensions() = dim;
  *window.add_dimensions() = dim;
  Shape window_shape = ShapeUtil::MakeShape(F32, {2, 2});
  Shape init_value_shape = ShapeUtil::MakeShape(F32, {});
  Shape float_scalar = ShapeUtil::MakeShape(F32, {});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      matrix_shape, init_value_shape, window, to_apply);

  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {4, 4}), inferred));
}
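
// In the fixture above, the operand is 8x16 and the window is 2x2 with
// stride 2, so the scatter source must be 4x8; the inferred result takes the
// operand shape.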
TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterProperShapes) {
  auto inferred_status_ok = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_IS_OK(inferred_status_ok.status());
  Shape inferred = inferred_status_ok.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(operand_shape_, inferred));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSourceShape) {
  Shape source_shape_fail = ShapeUtil::MakeShape(F32, {4, 6});
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_, window_, source_shape_fail,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Source shape does not match"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape1) {
  ProgramShape select_program_shape_fail =
      ShapeUtil::MakeProgramShape({ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must take 2 parameters"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape2) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}, f32_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function must have rank-0 PRED"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape3) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's first parameter"));
}

TEST_F(SelectAndScatterShapeInferenceTest, SelectAndScatterWrongSelectShape4) {
  ProgramShape select_program_shape_fail = ShapeUtil::MakeProgramShape(
      {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(U32, {})}, pred_);
  auto inferred_status_fail = ShapeInference::InferSelectAndScatterShape(
      operand_shape_, select_program_shape_fail, window_, source_shape_,
      init_value_shape_, scatter_program_shape_);
  ASSERT_FALSE(inferred_status_fail.ok());
  ASSERT_THAT(inferred_status_fail.status().error_message(),
              HasSubstr("Select function's second parameter"));
}
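
// For each spatial dimension, the convolution output size is
//   floor((base_dilated_input + padding_low + padding_high
//          - window_dilated_size) / stride) + 1,
// where base_dilated_input = (input - 1) * base_dilation + 1 and
// window_dilated_size = (window_size - 1) * window_dilation + 1.
// In the test below: x0 = floor((3 + 1 + 1 - 3) / 2) + 1 = 2 and
// x1 = floor((4 + 0 + 0 - 2) / 1) + 1 = 3, giving output {10, 12, 2, 3}.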
TEST_F(ShapeInferenceTest, Convolve) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 2, 3}),
                               inferred_shape));
}
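
// With window dilation 6, a size-3 window spans (3 - 1) * 6 + 1 = 13
// elements, so x0 = floor((103 - 13) / 3) + 1 = 31; with dilation 2 a size-2
// window spans 3, so x1 = floor((4 + 2 + 1 - 3) / 1) + 1 = 5.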
TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(6);
  dim0->set_base_dilation(1);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(2);
  dim1->set_base_dilation(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 31, 5}),
                               inferred_shape));
}
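
// Base dilation 6 expands the size-3 input to (3 - 1) * 6 + 1 = 13, so
// x0 = floor((13 - 4) / 3) + 1 = 4; base dilation 2 expands the size-4 input
// to 7, so x1 = floor((7 + 2 + 1 - 2) / 1) + 1 = 9.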
TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) {
  ConvolutionDimensionNumbers dnums;

  // Dimension order: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 12, 11, 4});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  Window window;
  auto dim0 = window.add_dimensions();
  dim0->set_size(4);
  dim0->set_stride(3);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(6);

  auto dim1 = window.add_dimensions();
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {10, 12, 4, 9}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) {
  // Dimension order for this test: batch, feature, x0, x1
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2});

  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(3);
  dnums.set_output_batch_dimension(3);
  dnums.set_input_feature_dimension(2);
  dnums.set_output_feature_dimension(2);
  dnums.add_input_spatial_dimensions(0);
  dnums.add_output_spatial_dimensions(0);
  dnums.add_input_spatial_dimensions(1);
  dnums.add_output_spatial_dimensions(1);
  dnums.set_kernel_input_feature_dimension(0);  // duplicated with kernel_x0
  dnums.set_kernel_output_feature_dimension(3);
  dnums.add_kernel_spatial_dimensions(0);
  dnums.add_kernel_spatial_dimensions(1);

  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(2);
  dim0->set_stride(1);
  dim0->set_padding_low(0);
  dim0->set_padding_high(0);
  dim1->set_size(3);
  dim1->set_stride(2);
  dim1->set_padding_low(1);
  dim1->set_padding_high(1);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/1,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("each dimension exactly once"));
}

TEST_F(ShapeInferenceTest, ConvolveBatchGroupCountUnequalOutputFeature) {
  ConvolutionDimensionNumbers dnums;
  dnums.set_input_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.set_kernel_input_feature_dimension(0);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(2);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.set_output_batch_dimension(0);
  dnums.set_output_feature_dimension(1);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(3);
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {60, 38, 17, 13});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {38, 10, 4, 4});
  Window window;
  auto dim0 = window.add_dimensions();
  auto dim1 = window.add_dimensions();
  dim0->set_size(4);
  dim1->set_size(4);
  dim0->set_padding_low(0);
  dim0->set_padding_high(2);
  dim1->set_padding_low(2);
  dim1->set_padding_high(1);
  dim0->set_stride(1);
  dim1->set_stride(1);
  dim0->set_window_dilation(3);
  dim1->set_window_dilation(2);
  auto inferred_status = ShapeInference::InferConvolveShape(
      lhs_shape, rhs_shape, /*feature_group_count=*/1, /*batch_group_count=*/6,
      window, dnums, /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("to be a multiple of batch group count"));
}

struct ConvolveArgs {
  Shape lhs_shape;
  Shape rhs_shape;
  ConvolutionDimensionNumbers dnums;
  Window window;
};

ConvolveArgs MakeConvolveArgs(PrimitiveType lhs_type, PrimitiveType rhs_type) {
  ConvolveArgs args;
  ConvolutionDimensionNumbers& dnums = args.dnums;

  // Dimension order: batch, feature, x0, x1
  args.lhs_shape = ShapeUtil::MakeShape(lhs_type, {10, 11, 3, 4});
  dnums.set_input_batch_dimension(0);
  dnums.set_output_batch_dimension(0);
  dnums.set_input_feature_dimension(1);
  dnums.set_output_feature_dimension(1);
  dnums.add_input_spatial_dimensions(2);
  dnums.add_output_spatial_dimensions(2);
  dnums.add_input_spatial_dimensions(3);
  dnums.add_output_spatial_dimensions(3);

  // Dimension order: x1, batch, feature, x0
  args.rhs_shape = ShapeUtil::MakeShape(rhs_type, {2, 12, 11, 3});
  dnums.set_kernel_input_feature_dimension(2);
  dnums.set_kernel_output_feature_dimension(1);
  dnums.add_kernel_spatial_dimensions(3);
  dnums.add_kernel_spatial_dimensions(0);

  auto dim0 = args.window.add_dimensions();
  auto dim1 = args.window.add_dimensions();
  dim0->set_size(3);
  dim0->set_stride(2);
  dim0->set_padding_low(1);
  dim0->set_padding_high(1);
  dim0->set_window_dilation(1);
  dim0->set_base_dilation(1);
  dim1->set_size(2);
  dim1->set_stride(1);
  dim1->set_padding_low(0);
  dim1->set_padding_high(0);
  dim1->set_window_dilation(1);
  dim1->set_base_dilation(1);
  return args;
}
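
// The next four tests check that element-type inference for convolution is
// commutative: conv(a, b) infers the same element type as conv(b, a). Mixed
// BF16/F16 operands infer BF16 and mixed S32/U32 operands infer S32,
// regardless of operand order.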
TEST_F(ShapeInferenceTest, ConvolveWithBF16_F16) {
  ConvolveArgs args = MakeConvolveArgs(BF16, F16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithF16_BF16) {
  ConvolveArgs args = MakeConvolveArgs(F16, BF16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithS32_U32) {
  ConvolveArgs args = MakeConvolveArgs(S32, U32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithU32_S32) {
  ConvolveArgs args = MakeConvolveArgs(U32, S32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/absl::nullopt));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest, ConvolveWithPreferredElementTypeSameAsInferredType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(S32, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(F32, F32);
  TF_ASSERT_OK_AND_ASSIGN(
      Shape inferred_shape,
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/BF16));
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(BF16, {10, 12, 2, 3}),
                               inferred_shape));
}

TEST_F(ShapeInferenceTest,
       FloatingPointConvolveWithInvalidPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(BF16, BF16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest,
       IntegralConvolveWithFloatingPointPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/F32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest,
       ConvolveWithPreferredElementTypeWithDifferentSignedness) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/U32)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}

TEST_F(ShapeInferenceTest, ConvolveWithNarrowerPreferredElementType) {
  ConvolveArgs args = MakeConvolveArgs(S8, S16);
  auto inferred_status =
      ShapeInference::InferConvolveShape(
          args.lhs_shape, args.rhs_shape, /*feature_group_count=*/1,
          /*batch_group_count=*/1, args.window, args.dnums,
          /*preferred_element_type=*/S8)
          .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}

namespace fft {

static const char* unsupported_rank = "only supports ranks 1-3";
static const char* invalid_rank = "requires input of at least same rank";
static const char* requires_complex_input = "requires complex input type";
static const char* requires_f32_input = "requires F32 or F64 input type";
static const char* dimensions_match = "innermost dimensions match fft_length";
static const char* innermost_dimension_matches =
    "innermost dimension matches fft_length/2+1";

static void Pass(const Shape& shape, FftType type,
                 absl::Span<const int64> length, const Shape& expected_shape) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred_shape, expected_shape));
}

static void Fail(const Shape& shape, FftType type,
                 absl::Span<const int64> length, absl::string_view message) {
  auto inferred_status = ShapeInference::InferFftShape(shape, type, length);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr(std::string(message)));
}

}  // namespace fft

TEST_F(ShapeInferenceTest, InferFftShapeTestFftRanks) {
  FftType type = FftType::FFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestFftTypes) {
  FftType type = FftType::FFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIfftRanks) {
  FftType type = FftType::IFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 8});
  fft::Fail(shape, type, {}, fft::unsupported_rank);
  fft::Pass(shape, type, {8}, shape);
  fft::Pass(shape, type, {16, 8}, shape);
  fft::Fail(shape, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIfftTypes) {
  FftType type = FftType::IFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_c128);
}
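
// RFFT maps a real input to a complex output whose innermost dimension is
// fft_length / 2 + 1 (here 8 / 2 + 1 = 5); the other dimensions carry over.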
TEST_F(ShapeInferenceTest, InferFftShapeTestRfftRanks) {
  FftType type = FftType::RFFT;
  Shape shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestRfftDimensions) {
  FftType type = FftType::RFFT;
  Shape shape = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(shape, type, {4}, fft::dimensions_match);
  fft::Fail(shape, type, {16, 4}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 16}, fft::dimensions_match);

  Shape zero_shape_in = ShapeUtil::MakeShape(F32, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(C64, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  Shape even_shape_in = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_in = ShapeUtil::MakeShape(F32, {16, 9});
  Shape shape_out = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Pass(even_shape_in, type, {16, 8}, shape_out);
  fft::Pass(odd_shape_in, type, {16, 9}, shape_out);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestRfftTypes) {
  FftType type = FftType::RFFT;
  Shape shape_c64 = ShapeUtil::MakeShape(C64, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 8});
  fft::Fail(shape_c64, type, {16, 8}, fft::requires_f32_input);
  fft::Fail(shape_c128, type, {16, 8}, fft::requires_f32_input);
}
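
// IRFFT inverts the RFFT shape mapping: the complex input's innermost
// dimension must equal fft_length / 2 + 1 (5 for fft_length 8 or 9), and the
// real output's innermost dimension is fft_length itself.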
TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftRanks) {
  FftType type = FftType::IRFFT;
  Shape shape_in = ShapeUtil::MakeShape(C64, {16, 5});
  Shape shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  fft::Fail(shape_in, type, {}, fft::unsupported_rank);
  fft::Pass(shape_in, type, {8}, shape_out);
  fft::Pass(shape_in, type, {16, 8}, shape_out);
  fft::Fail(shape_in, type, {32, 16, 8}, fft::invalid_rank);
  fft::Fail(shape_in, type, {64, 32, 16, 8}, fft::unsupported_rank);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftDimensions) {
  FftType type = FftType::IRFFT;
  Shape shape = ShapeUtil::MakeShape(C64, {16, 5});
  fft::Fail(shape, type, {5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {16, 5}, fft::innermost_dimension_matches);
  fft::Fail(shape, type, {8, 8}, fft::dimensions_match);
  fft::Fail(shape, type, {8, 9}, fft::dimensions_match);

  Shape zero_shape_in = ShapeUtil::MakeShape(C64, {16, 0});
  Shape zero_shape_out = ShapeUtil::MakeShape(F32, {16, 0});
  fft::Pass(zero_shape_in, type, {0}, zero_shape_out);
  fft::Pass(zero_shape_in, type, {16, 0}, zero_shape_out);

  Shape even_shape_out = ShapeUtil::MakeShape(F32, {16, 8});
  Shape odd_shape_out = ShapeUtil::MakeShape(F32, {16, 9});
  fft::Pass(shape, type, {16, 8}, even_shape_out);
  fft::Pass(shape, type, {16, 9}, odd_shape_out);
}

TEST_F(ShapeInferenceTest, InferFftShapeTestIrfftTypes) {
  FftType type = FftType::IRFFT;
  Shape shape_f32 = ShapeUtil::MakeShape(F32, {16, 8});
  Shape shape_c128 = ShapeUtil::MakeShape(C128, {16, 5});
  Shape shape_f64_out = ShapeUtil::MakeShape(F64, {16, 8});
  fft::Fail(shape_f32, type, {16, 8}, fft::requires_complex_input);
  fft::Pass(shape_c128, type, {16, 8}, shape_f64_out);
}
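
// Map applies a scalar function elementwise over equally-shaped operands, so
// the result keeps the argument dimensions and takes the function's return
// element type.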
TEST_F(ShapeInferenceTest, MapThatChangesElementType) {
  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  Shape expected = ShapeUtil::MakeShape(S32, {20});
  EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, Map) {
  auto inferred_status_r1f32 = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie()));

  // It's OK to provide a single argument, as long as the applied arity matches
  // (this degenerates to a Map).
  auto inferred_status_r1f32_one = ShapeInference::InferMapShape(
      {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0});
  EXPECT_IS_OK(inferred_status_r1f32_one.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie()));

  auto inferred_status_r2s32 = ShapeInference::InferMapShape(
      {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_},
      ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1});
  EXPECT_IS_OK(inferred_status_r2s32.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie()));

  auto no_args_error = ShapeInference::InferMapShape(
      {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {});
  ASSERT_FALSE(no_args_error.ok());
  ASSERT_THAT(no_args_error.status().error_message(),
              HasSubstr("expects at least one argument"));

  auto args_diff_shapes_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_64_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(args_diff_shapes_error.ok());
  ASSERT_THAT(args_diff_shapes_error.status().error_message(),
              HasSubstr("requires all operands to have the same shape"));

  auto arity_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_),
      {0});
  ASSERT_FALSE(arity_error.ok());
  ASSERT_THAT(arity_error.status().error_message(),
              HasSubstr("function arity must match"));

  auto output_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0});
  ASSERT_FALSE(output_shape_error.ok());
  ASSERT_THAT(output_shape_error.status().error_message(),
              HasSubstr("result has to be a scalar"));

  auto param_shape_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0});
  ASSERT_FALSE(param_shape_error.ok());
  ASSERT_THAT(param_shape_error.status().error_message(),
              HasSubstr("parameter has to be a scalar"));

  auto param_element_type_error = ShapeInference::InferMapShape(
      {&vector_32_, &vector_32_},
      ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0});
  ASSERT_FALSE(param_element_type_error.ok());
  ASSERT_THAT(param_element_type_error.status().error_message(),
              HasSubstr("parameter type has to match argument"));

  Shape arg = ShapeUtil::MakeShape(F32, {20});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_);
  auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0});
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie()));

  auto inferred_status_error1 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match number of arguments"));

  auto inferred_status_error2 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto inferred_status_error3 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("has to be a scalar"));

  auto inferred_status_error5 = ShapeInference::InferMapShape(
      {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("parameter type has to match argument"));
}

TEST_F(ReduceShapeInferenceTest, ReduceVectorToScalar) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {128}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongMiddleDimension) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2, 4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {4}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongLastTwoDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {2}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{1, 2});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongFirstAndLastDimensions) {
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 2});

  // Check that the order of dimensions_to_reduce doesn't matter.
  ExpectInferredReduceShape(ShapeUtil::MakeShape(F32, {3}),
                            ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{2, 0});
}

TEST_F(ReduceShapeInferenceTest, ReduceCubeAmongAllDimensions) {
  ExpectInferredReduceShape(f32_, ShapeUtil::MakeShape(F32, {2, 3, 4}),
                            /*dimensions_to_reduce=*/{0, 1, 2});
}
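
// A variadic reduce over N arrays takes N init values and a reducer with
// 2 * N scalar parameters that returns an N-tuple; the inferred shape is a
// tuple of the reduced shapes.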
TEST_F(ReduceShapeInferenceTest, ReduceMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeTupleShape({f32_, s32_}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ReduceWindowMultiOutput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  VLOG(2) << inferred_status.ValueOrDie().ToString() << "\n";
  EXPECT_IS_OK(inferred_status.status());
  EXPECT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {5, 2, 0}),
                                 ShapeUtil::MakeShape(S32, {5, 2, 0})}),
      inferred_status.ValueOrDie()));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_, f32_, s32_},
                                  ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must take 4 parameters, but takes 6 parameter(s)"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr(
          "parameter shape differs from the result shape: s32[] vs f32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerInput3) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape({}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("must have at least 2 arguments, has 0"));
}

TEST_F(ReduceShapeInferenceTest, ErrorBadReduceWindowInput) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3, 1});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3, 1});
  std::vector<const Shape*> args = {&f32_arg_shape, &s32_arg_shape};
  std::vector<const Shape*> inits = {&f32_, &s32_};
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, f32_, f32_, f32_}, ShapeUtil::MakeTupleShape({f32_, s32_}));
  std::vector<int64> window_dimensions = {1, 2, 4};
  std::vector<int64> window_strides = {1, 1, 1};
  std::vector<std::pair<int64, int64>> padding_values =
      MakePadding(AsInt64Slice(f32_arg_shape.dimensions()), window_dimensions,
                  window_strides, Padding::kValid);
  TF_ASSERT_OK_AND_ASSIGN(
      Window window,
      ShapeInference::InferWindowFromDimensions(
          window_dimensions, window_strides, padding_values, {}, {}));
  auto inferred_status = ShapeInference::InferReduceWindowShape(
      absl::MakeSpan(args), absl::MakeSpan(inits), window, to_apply);
  EXPECT_FALSE(inferred_status.status().ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("f32[] vs s32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput1) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply =
      ShapeUtil::MakeProgramShape({f32_, s32_, f32_, s32_}, f32_);
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but produces a scalar"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerOutput2) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {f32_, s32_, f32_, s32_}, ShapeUtil::MakeTupleShape({f32_, s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(
      inferred_status.status().error_message(),
      HasSubstr("must produce a tuple with 2 elements, but has 3 elements"));
}

TEST_F(ReduceShapeInferenceTest, ErrorMultiOutputBadReducerBoth) {
  Shape f32_arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape s32_arg_shape = ShapeUtil::MakeShape(S32, {5, 3});
  ProgramShape to_apply = ShapeUtil::MakeProgramShape(
      {s32_, s32_, s32_, s32_}, ShapeUtil::MakeTupleShape({s32_, s32_}));
  auto inferred_status = ShapeInference::InferReduceShape(
      {&f32_arg_shape, &s32_arg_shape, &f32_, &s32_}, {0, 1}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("accumulator shape at index 0 differs from the "
                        "init_value shape: s32[] vs f32[]"));
}

TEST_F(ReduceShapeInferenceTest, ErrorOutOfBoundsDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{3, 4}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("out-of-bounds dimension"));
}

TEST_F(ReduceShapeInferenceTest, ErrorToApplyArity) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("take 2 parameters"));
}

TEST_F(ReduceShapeInferenceTest, ErrorElementTypeVsApplyType) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, s32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status =
      ShapeInference::InferReduceShape({&arg_shape, &f32_},
                                       /*dimensions_to_reduce=*/{0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("0-th parameter shape differs"));
}

TEST_F(ReduceShapeInferenceTest, ReduceWithRepeatedReduceDimension) {
  ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  auto inferred_status = ShapeInference::InferReduceShape(
      {&arg_shape, &f32_},
      /*dimensions_to_reduce=*/{0, 0}, to_apply);
  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("Duplicate reduction dimension: 0"));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 64}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceWithDynamicDimensions) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64}, {true, true});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {33, 64}, {1, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {1, 64}, {false, true}), inferred));
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStrides) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {32, 0}, {64, 64}, {2, 4});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {16, 16}), inferred));
}
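
// A strided slice yields ceil((limit - start) / stride) elements per
// dimension, even when the division is not integral: ceil((20 - 15) / 2) = 3
// and ceil((13 - 0) / 4) = 4.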
|
|
|
|
TEST_F(ShapeInferenceTest, InferSliceShapeRank2WithStridesNotIntegral) {
|
|
Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
|
|
auto inferred_status =
|
|
ShapeInference::InferSliceShape(matrix_shape, {15, 0}, {20, 13}, {2, 4});
|
|
ASSERT_IS_OK(inferred_status.status());
|
|
Shape inferred = inferred_status.ValueOrDie();
|
|
ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {3, 4}), inferred));
|
|
}
|
|
|
|
TEST_F(ShapeInferenceTest, InferInvalidStride) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {0, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferOobSliceShapeRank2) {
  Shape matrix_shape = ShapeUtil::MakeShape(F32, {128, 64});
  auto inferred_status =
      ShapeInference::InferSliceShape(matrix_shape, {127, 0}, {129, 2}, {1, 1});
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_EQ(tensorflow::error::INVALID_ARGUMENT,
            inferred_status.status().code());
}

TEST_F(ShapeInferenceTest, InferSliceShapeRank1) {
  Shape vector_shape = ShapeUtil::MakeShape(F32, {17});
  auto inferred_status =
      ShapeInference::InferSliceShape(vector_shape, {2}, {4}, {1});
  ASSERT_TRUE(inferred_status.ok());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(inferred, ShapeUtil::MakeShape(F32, {2})));
}

TEST_F(ShapeInferenceTest, InferConstIndexShape) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferred0_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 0);
  auto inferred1_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 1);
  ASSERT_IS_OK(inferred0_status.status());
  ASSERT_IS_OK(inferred1_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred0_status.ValueOrDie()));
  ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
  Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
  auto inferredNegative_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
  auto inferred2_status =
      ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
  ASSERT_FALSE(inferredNegative_status.ok());
  ASSERT_FALSE(inferred2_status.ok());
  EXPECT_THAT(inferredNegative_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
  EXPECT_THAT(inferred2_status.status().error_message(),
              HasSubstr("attempt to index out of tuple bounds"));
}

TEST_F(ShapeInferenceTest, InferPowShape) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status = ShapeInference::InferBinaryOpShape(
      HloOpcode::kPower, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ten_floats, inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferCompareShape) {
  auto ten_floats = ShapeUtil::MakeShape(F32, {10});
  auto inferred_status = ShapeInference::InferBinaryOpShape(
      HloOpcode::kCompare, ten_floats, f32_, {});
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(PRED, {10}),
                               inferred_status.ValueOrDie()));
}

TEST_F(ShapeInferenceTest, InferReshapeDegenerateCombine) {
  // [1, <=1]
  //   | reshape
  // [<=1]
  //
  // Both output dimensions can be dynamic; use inferred_dimension to
  // tie-break.
  auto operand = ShapeUtil::MakeShape(F32, {1, 1}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {1},
                                                  /*inferred_dimension=*/-1);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1}, {true}), status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, InferReshapeSplit) {
  // [<=10]
  //   | reshape
  // [1, 10]
  //
  // Both output dimensions can be dynamic; use inferred_dimension to
  // tie-break.
  auto operand = ShapeUtil::MakeShape(F32, {10}, {true});
  auto status = ShapeInference::InferReshapeShape(operand, {0}, {1, 10},
                                                  /*inferred_dimension=*/0);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {1, 10}, {true, false}),
            status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, InferReshapeCombine) {
  // [6, <=10]
  //   | reshape
  // [<=60]
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {60},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {60}, {true}), status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, UnchangedDimension) {
  // [6, <=10]
  //   | reshape
  // [2, 3, <=10]
  auto operand = ShapeUtil::MakeShape(F32, {6, 10}, {false, true});
  auto status = ShapeInference::InferReshapeShape(operand, {1, 0}, {2, 3, 10},
                                                  /*inferred_dimension=*/-11);
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {2, 3, 10}, {false, false, true}),
            status.ValueOrDie());
}

TEST_F(ShapeInferenceTest, InferDynamicBroadcast) {
  // CHECK:
  // %broadcast = f32[15,<=15]{1,0} broadcast(f32[<=15]{0}), dimensions={1}

  auto operand_shape = ShapeUtil::MakeShape(F32, {15}, {true});
  auto inferred_status =
      ShapeInference::InferBroadcastShape(operand_shape, {15});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_EQ(ShapeUtil::MakeShape(F32, {15, 15}, {false, true}), inferred);
}

TEST_F(ShapeInferenceTest, BroadcastScalar) {
  for (auto element_type : {F32, U32, S8}) {
    const Shape scalar_shape = ShapeUtil::MakeShape(element_type, {});
    {  // no-op scalar broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(scalar_shape, status.ValueOrDie()));
    }
    const Shape oned_shape = ShapeUtil::MakeShape(element_type, {3});
    {  // scalar -> 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    }
    {  // no-op 1d broadcast
      auto status = ShapeInference::InferBroadcastShape(oned_shape, {});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(oned_shape, status.ValueOrDie()));
    }
    const Shape twod_shape = ShapeUtil::MakeShape(element_type, {2, 3});
    {  // scalar -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(scalar_shape, {2, 3});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    }
    {  // 1d -> 2d broadcast
      auto status = ShapeInference::InferBroadcastShape(oned_shape, {2});
      ASSERT_IS_OK(status.status());
      ASSERT_TRUE(ShapeUtil::Equal(twod_shape, status.ValueOrDie()));
    }
  }
}

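// A dot with a scalar operand and no contracting dimensions behaves like a
// scalar multiply, so the result takes the shape of the array operand.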
// scalar <dot> vector: ok
TEST_F(ShapeInferenceTest, ScalarDotVector) {
  DotDimensionNumbers dot_dnums;
  auto inferred_status = ShapeInference::InferDotOpShape(
      f32_, vector_32_, dot_dnums, /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(inferred_status.ok());
  EXPECT_EQ(inferred_status.ValueOrDie(), vector_32_);
}

// 3D <dot> 2D: ok
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status = ShapeInference::InferDotOpShape(
      ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums,
      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(inferred_status.ok());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {32, 32, 64})));
}

// vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status =
      ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}

// matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}

// vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status =
      ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}

// matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status_match =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(matrix_32_48_);
  auto inferred_status_mismatch =
      ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status_mismatch.ok());
}

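// A general dot's output shape is the batch dimensions, then the
// non-contracting lhs dimensions, then the non-contracting rhs dimensions;
// below, {5, 2} + {11} + {14} yields f32[5, 2, 11, 14].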
// BatchMatMul with two batch dimensions and one contracting dimension.
TEST_F(ShapeInferenceTest, DotGeneral) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
  Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(1);

  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto inferred_status_match =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(
      ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
      << "inferred: "
      << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
      << " expected: " << ShapeUtil::HumanString(output_shape);
}

// BatchMatMul with two contracting dimensions fails.
TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Must specify the same number of contracting "
                        "dimensions for lhs and rhs."));
}

TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsPasses) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 2, 14});
  Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(2);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  EXPECT_TRUE(inferred_status.ok());
  EXPECT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), output_shape));
}

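// set-dimension-size expects the new size operand to be an S32 scalar; both a
// rank-1 S32 value and a scalar of the wrong element type are rejected below.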
TEST_F(ShapeInferenceTest, ErrorSetDimensionSize) {
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape val_shape = ShapeUtil::MakeShape(S32, {1});
  auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}

TEST_F(ShapeInferenceTest, ErrorSetDimensionSizeWrongType) {
  Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3});
  Shape val_shape = ShapeUtil::MakeShape(U32, {});
  auto inferred_status = ShapeInference::InferSetDimensionSizeShape(
      arg_shape, val_shape, /*dimension=*/0);

  EXPECT_FALSE(inferred_status.ok());
  EXPECT_THAT(inferred_status.status().error_message(),
              HasSubstr("value has to be S32 scalar"));
}

// BatchMatMul with different batch dimension sizes fails.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(1);
  dot_dnums.add_rhs_batch_dimensions(0);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("Batch dimension sizes must match"));
}

// BatchMatMul with different batch dimension numbers passes.
TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersPasses) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(2);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_TRUE(inferred_status.ok());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(),
                               ShapeUtil::MakeShape(F32, {2, 11, 14})));
}

// BatchMatMul with out-of-range dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(3);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("A dimension number is out of range"));
}

// BatchMatMul with non-unique dimension numbers fails.
TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
  Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
  Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});

  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(0);
  dot_dnums.add_lhs_batch_dimensions(0);

  dot_dnums.add_rhs_contracting_dimensions(0);
  dot_dnums.add_rhs_batch_dimensions(1);

  auto inferred_status =
      ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums,
                                      /*preferred_element_type=*/absl::nullopt);
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.status().error_message(),
              HasSubstr("A dimension number is not unique"));
}

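// preferred_element_type requests a wider accumulation type for the dot
// result; e.g. an S8 x S16 dot with preferred_element_type=S32 yields S32.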
TEST_F(ShapeInferenceTest, DotWithIntegralPreferredElementType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(S8, {32, 32}),
                              ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/S32));
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(S32, {32, 32})));
}

TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeSameAsInferredType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/F32));
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(F32, {32, 32})));
}

TEST_F(ShapeInferenceTest, FloatingPointDotWithNarrowerPreferredElementType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  TF_ASSERT_OK_AND_ASSIGN(Shape inferred_shape,
                          ShapeInference::InferDotOpShape(
                              ShapeUtil::MakeShape(BF16, {32, 32}),
                              ShapeUtil::MakeShape(F32, {32, 32}), dot_dnums,
                              /*preferred_element_type=*/BF16));
  EXPECT_TRUE(
      ShapeUtil::Equal(inferred_shape, ShapeUtil::MakeShape(BF16, {32, 32})));
}

TEST_F(ShapeInferenceTest, FloatingPointDotWithInvalidPreferredElementType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status = ShapeInference::InferDotOpShape(
                             ShapeUtil::MakeShape(BF16, {32, 32}),
                             ShapeUtil::MakeShape(BF16, {32, 32}), dot_dnums,
                             /*preferred_element_type=*/S32)
                             .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest, IntegralDotWithFloatingPointPreferredElementType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status = ShapeInference::InferDotOpShape(
                             ShapeUtil::MakeShape(S8, {32, 32}),
                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                             /*preferred_element_type=*/F32)
                             .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must both be integral or both be floating point"));
}

TEST_F(ShapeInferenceTest, DotWithPreferredElementTypeWithDifferentSignedness) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status = ShapeInference::InferDotOpShape(
                             ShapeUtil::MakeShape(S8, {32, 32}),
                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                             /*preferred_element_type=*/U32)
                             .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must have the same signedness as the original type"));
}

TEST_F(ShapeInferenceTest, DotWithNarrowerPreferredElementType) {
  DotDimensionNumbers dot_dnums;
  dot_dnums.add_lhs_contracting_dimensions(1);
  dot_dnums.add_rhs_contracting_dimensions(0);
  auto inferred_status = ShapeInference::InferDotOpShape(
                             ShapeUtil::MakeShape(S8, {32, 32}),
                             ShapeUtil::MakeShape(S16, {32, 32}), dot_dnums,
                             /*preferred_element_type=*/S8)
                             .status();
  ASSERT_FALSE(inferred_status.ok());
  ASSERT_THAT(inferred_status.error_message(),
              HasSubstr("must not be narrower than the original type"));
}

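// broadcast_dimensions maps each dimension of the lower-rank operand to a
// dimension of the higher-rank operand; e.g. a vector of length 8 lines up
// with dimension 1 of an f32[16, 8] matrix.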
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
  // Test variations of broadcasting a vector for a binary add with a
  // matrix.
  const Shape mat = ShapeUtil::MakeShape(F32, {16, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape vec16 = ShapeUtil::MakeShape(F32, {16});

  auto inferred_status_match =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {1});
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));

  auto inferred_status_mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec8, {0});
  ASSERT_FALSE(inferred_status_mismatch.ok());

  inferred_status_match =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {0});
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), mat));

  inferred_status_mismatch =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, mat, vec16, {1});
  ASSERT_FALSE(inferred_status_mismatch.ok());
}

TEST_F(ShapeInferenceTest, BinOpBroadcastCubeMatrix) {
  // Test variations of broadcasting a matrix for a binary add with a cube.
  const Shape cube = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix16_4 = ShapeUtil::MakeShape(F32, {16, 4});
  const Shape matrix16_8 = ShapeUtil::MakeShape(F32, {16, 8});

  auto inferred_status_match = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, matrix8_4, {1, 2});
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));

  inferred_status_match = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, matrix16_4, {0, 2});
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));

  inferred_status_match = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, cube, matrix16_8, {0, 1});
  ASSERT_IS_OK(inferred_status_match.status());
  ASSERT_TRUE(ShapeUtil::Equal(inferred_status_match.ValueOrDie(), cube));
}

TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) {
  // Test various errors with the broadcast argument.
  const Shape tensor = ShapeUtil::MakeShape(F32, {16, 8, 4});
  const Shape tensor8_8_8 = ShapeUtil::MakeShape(F32, {8, 8, 8});
  const Shape vec8 = ShapeUtil::MakeShape(F32, {8});
  const Shape matrix8_4 = ShapeUtil::MakeShape(F32, {8, 4});
  const Shape matrix8_8 = ShapeUtil::MakeShape(F32, {8, 8});

  // "magical" broadcast rejected
  auto inferred_status_error1 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Automatic"));

  // broadcast_dimension out of bounds for tensor's rank
  auto inferred_status_error2 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {3});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              ContainsRegex("Broadcast dimension number .* too large"));

  // broadcast_dimension doesn't match corresponding dimension
  auto inferred_status_error3 =
      ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, tensor, vec8, {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Broadcast dimension 0 mismatch"));

  // broadcast_dimensions list too long
  auto inferred_status_error4 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {0, 1, 2});
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("broadcast_dimensions has to match"));

  // there's a dimension above the rank of the tensor
  auto inferred_status_error5 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {3, 0});
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              ContainsRegex("dimension number .* too large"));

  // broadcasting dimensions don't match in this order
  auto inferred_status_error6 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor, matrix8_4, {2, 1});
  ASSERT_FALSE(inferred_status_error6.ok());
  ASSERT_THAT(inferred_status_error6.status().error_message(),
              HasSubstr("dimension 0 mismatch"));

  // The following two tests make sure that broadcasting dimensions are listed
  // in a proper (strictly increasing) order, even if the lower-rank array
  // matches the higher-rank array in many different ways.
  auto inferred_status_error7 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {0, 0});
  ASSERT_FALSE(inferred_status_error7.ok());
  ASSERT_THAT(inferred_status_error7.status().error_message(),
              HasSubstr("dimensions order is wrong"));

  auto inferred_status_error8 = ShapeInference::InferBinaryOpShape(
      HloOpcode::kAdd, tensor8_8_8, matrix8_8, {1, 0});
  ASSERT_FALSE(inferred_status_error8.ok());
  ASSERT_THAT(inferred_status_error8.status().error_message(),
              HasSubstr("dimensions order is wrong"));
}

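// A while's condition must map the loop state T to PRED and its body must map
// T to T; the inferred shape of the loop is T itself.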
// Tests for the while instruction with proper shapes.
TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) {
  Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
  ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);
  auto inferred_status =
      ShapeInference::InferWhileShape(cond, body, result_shape);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(result_shape, inferred));
}

// Tests for the while instruction with wrong shapes.
TEST_F(ShapeInferenceTest, WhileWithBadShapes) {
  Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});
  ProgramShape cond = ShapeUtil::MakeProgramShape({result_shape}, pred_);
  ProgramShape body = ShapeUtil::MakeProgramShape({result_shape}, result_shape);

  auto bad_shape_1 = ShapeUtil::MakeProgramShape({s32_, result_shape}, pred_);
  auto inferred_status_error1 =
      ShapeInference::InferWhileShape(bad_shape_1, body, result_shape);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Condition must take 1 arguments"));

  auto bad_shape_2 =
      ShapeUtil::MakeProgramShape({s32_, result_shape}, result_shape);
  auto inferred_status_error2 =
      ShapeInference::InferWhileShape(cond, bad_shape_2, result_shape);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("Body must take 1 arguments"));

  auto bad_shape_3 = ShapeUtil::MakeProgramShape({result_shape}, s32_);
  auto inferred_status_error3 =
      ShapeInference::InferWhileShape(bad_shape_3, body, result_shape);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Condition must return a boolean"));

  auto bad_shape_4 = ShapeUtil::MakeProgramShape({result_shape}, vector_32_);
  auto inferred_status_error4 =
      ShapeInference::InferWhileShape(cond, bad_shape_4, result_shape);
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("parameter of condition and body"));
}

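// Concatenation sums the sizes along the concatenated dimension (32 + 32 = 64
// below), and a result dimension is dynamic if it is dynamic in either
// operand.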
// Tests for the concatenate instruction with dynamic shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithDynamicShapes) {
  auto dynamic_shape_1 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {true, false, false});
  auto dynamic_shape_2 =
      ShapeUtil::MakeShape(F32, {32, 160, 10}, {false, true, false});
  auto inferred_status = ShapeInference::InferConcatOpShape(
      {&dynamic_shape_1, &dynamic_shape_2}, /*dimension=*/0);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(
      ShapeUtil::MakeShape(F32, {64, 160, 10}, {true, true, false}), inferred));
}

// Tests for the concatenate instruction with proper shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithCorrectShapes) {
  auto inferred_status_1 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_}, /*dimension=*/0);
  ASSERT_IS_OK(inferred_status_1.status());
  Shape inferred_1 = inferred_status_1.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {96}), inferred_1));

  auto inferred_status_2 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_64_, &vector_32_}, /*dimension=*/0);
  ASSERT_IS_OK(inferred_status_2.status());
  Shape inferred_2 = inferred_status_2.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {128}), inferred_2));

  auto inferred_status_3 = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_, &matrix_32_48_}, /*dimension=*/1);
  ASSERT_IS_OK(inferred_status_3.status());
  Shape inferred_3 = inferred_status_3.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {32, 160}), inferred_3));
}

// Tests for the concatenate instruction with wrong shapes.
TEST_F(ShapeInferenceTest, ConcatenateWithBadShapes) {
  auto inferred_status_error1 =
      ShapeInference::InferConcatOpShape({}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("Concatenate expects at least one argument"));

  auto inferred_status_error2 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/-1);
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("dimension out of bounds: -1"));

  auto inferred_status_error3 =
      ShapeInference::InferConcatOpShape({&vector_32_}, /*dimension=*/1);
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("dimension out of bounds: 1"));

  Shape tuple = ShapeUtil::MakeTupleShape({vector_32_});
  auto inferred_status_error4 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &tuple}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error4.ok());
  ASSERT_THAT(
      inferred_status_error4.status().error_message(),
      HasSubstr("Expected array argument for operand of concatenation"));

  const Shape vector_s32 = ShapeUtil::MakeShape(S32, {32});
  auto inferred_status_error5 = ShapeInference::InferConcatOpShape(
      {&vector_32_, &vector_s32}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error5.ok());
  ASSERT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("concatenate arrays with different element types"));

  auto inferred_status_error6 = ShapeInference::InferConcatOpShape(
      {&matrix_32_48_, &matrix_32_64_}, /*dimension=*/0);
  ASSERT_FALSE(inferred_status_error6.ok());
  ASSERT_THAT(inferred_status_error6.status().error_message(),
              HasSubstr("concatenate arrays that differ in "
                        "dimensions other than the one being "
                        "concatenated"));
}

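// The padded size of a dimension is
//   input + edge_low + edge_high + interior * (input - 1);
// below, dim 0 is 10 + 0 + 2 + 3 * 9 = 39 and dim 1 is 25 + 1 + 5 + 0 = 31.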
TEST_F(ShapeInferenceTest, Pad) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});
  Shape padding_value_shape = ShapeUtil::MakeShape(F32, {});
  // Padding for dimension 0: {low: 0, high: 2, interior: 3}
  // Padding for dimension 1: {low: 1, high: 5, interior: 0}
  PaddingConfig padding_config;
  auto dimension0 = padding_config.add_dimensions();
  dimension0->set_edge_padding_low(0);
  dimension0->set_edge_padding_high(2);
  dimension0->set_interior_padding(3);
  auto dimension1 = padding_config.add_dimensions();
  dimension1->set_edge_padding_low(1);
  dimension1->set_edge_padding_high(5);
  dimension1->set_interior_padding(0);

  auto inferred_status = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(
      ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));

  dimension1->set_edge_padding_low(-20);
  dimension1->set_edge_padding_high(-10);
  auto negative_dimension_size = ShapeInference::InferPadShape(
      input_shape, padding_value_shape, padding_config);
  ASSERT_FALSE(negative_dimension_size.ok());
  ASSERT_THAT(negative_dimension_size.status().error_message(),
              HasSubstr("negative size for dimension 1"));
}

TEST_F(ShapeInferenceTest, Reverse) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});

  auto inferred_status = ShapeInference::InferReverseShape(input_shape, {0, 1});
  ASSERT_IS_OK(inferred_status.status());
  Shape inferred_shape = inferred_status.ValueOrDie();
  ASSERT_TRUE(ShapeUtil::Equal(input_shape, inferred_shape));
}

TEST_F(ShapeInferenceTest, ReverseInvalidDimension) {
  Shape input_shape = ShapeUtil::MakeShape(F32, {10, 25});

  auto inferred_status_error0 =
      ShapeInference::InferReverseShape(input_shape, {0, 2});
  ASSERT_FALSE(inferred_status_error0.ok());
  ASSERT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("out-of-bounds"));

  auto inferred_status_error1 =
      ShapeInference::InferReverseShape(input_shape, {0, -1});
  ASSERT_FALSE(inferred_status_error1.ok());
  ASSERT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("out-of-bounds"));

  auto inferred_status_error2 =
      ShapeInference::InferReverseShape(input_shape, {0, 0});
  ASSERT_FALSE(inferred_status_error2.ok());
  ASSERT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("duplicated"));

  Shape tuple_shape = ShapeUtil::MakeTupleShape({input_shape, input_shape});
  auto inferred_status_error3 =
      ShapeInference::InferReverseShape(tuple_shape, {0});
  ASSERT_FALSE(inferred_status_error3.ok());
  ASSERT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("Expected array argument"));
}

TEST_F(ShapeInferenceTest, Call) {
  auto inferred_status0 =
      ShapeInference::InferCallShape({}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  auto inferred_status1 = ShapeInference::InferCallShape(
      {&f32_, &s32_, &pred_, &vector_32_, &matrix_32_48_},
      ShapeUtil::MakeProgramShape(
          {f32_, s32_, pred_, vector_32_, matrix_32_48_}, s32matrix_64_64_));
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(
      ShapeUtil::Equal(s32matrix_64_64_, inferred_status1.ValueOrDie()));

  auto inferred_status_error0 = ShapeInference::InferCallShape(
      {}, ShapeUtil::MakeProgramShape({f32_}, f32_));
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("arity must match"));

  auto inferred_status_error1 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({}, f32_));
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("arity must match"));

  auto inferred_status_error2 = ShapeInference::InferCallShape(
      {&f32_}, ShapeUtil::MakeProgramShape({s32_}, f32_));
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("parameter must match argument"));
}

TEST_F(ShapeInferenceTest, Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {2, 3, 4, 5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {1, 2, 3, 0});
  EXPECT_IS_OK(inferred_shape_and_status);
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(inferred_shape,
                                    ShapeUtil::MakeShape(F32, {3, 4, 5, 2})));
}

TEST_F(ShapeInferenceTest, Rank1Transpose) {
  Shape a_shape = ShapeUtil::MakeShape(F32, {5});
  auto inferred_shape_and_status =
      ShapeInference::InferTransposeShape(a_shape, {0});
  EXPECT_IS_OK(inferred_shape_and_status);
  Shape inferred_shape = inferred_shape_and_status.ValueOrDie();
  EXPECT_TRUE(
      ShapeUtil::Compatible(inferred_shape, ShapeUtil::MakeShape(F32, {5})));
}

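// A conditional's predicate must be PRED (or S32 for the N-way indexed form);
// each branch computation takes exactly one argument, which must match the
// corresponding branch operand, and every branch must return the same shape.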
TEST_F(ShapeInferenceTest, ConditionalPred) {
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  auto inferred_status1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)},
      {matrix_32_48_, vector_32_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {matrix_32_48_, tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      f32_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("must be bool or int32"));

  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 0 must take 1 argument"));

  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 0 must match the shape of the only "
                        "parameter of branch computation 0"));

  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_})});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  auto inferred_status_error4 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("branch operand 1 must match the shape of the only "
                        "parameter of branch computation 1"));

  auto inferred_status_error5 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error5.ok());
  EXPECT_THAT(inferred_status_error5.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 1 "
                        "computation must have the same shape"));
}

TEST_F(ShapeInferenceTest, ConditionalIndexed) {
  auto r0s32 = ShapeUtil::MakeShape(S32, {});
  auto inferred_status0 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_64_, vector_64_});
  EXPECT_IS_OK(inferred_status0.status());
  EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie()));

  auto inferred_status1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_),
       ShapeUtil::MakeProgramShape({vector_32_}, vector_64_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_)},
      {matrix_32_48_, vector_32_, matrix_32_48_});
  EXPECT_IS_OK(inferred_status1.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie()));

  auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_});
  auto inferred_status2 = ShapeInference::InferConditionalShape(
      r0s32, {ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)},
      {tuple_f32_v32});
  EXPECT_IS_OK(inferred_status2.status());
  EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie()));

  auto inferred_status_error0 = ShapeInference::InferConditionalShape(
      pred_,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, f32_)},
      {vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error0.ok());
  EXPECT_THAT(inferred_status_error0.status().error_message(),
              HasSubstr("2 == branch_computations.size()"));

  auto inferred_status_error1 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_),
       ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_),
       ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)},
      {matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}),
       matrix_32_48_});
  EXPECT_FALSE(inferred_status_error1.ok());
  EXPECT_THAT(inferred_status_error1.status().error_message(),
              HasSubstr("branch computation 1 must take 1 argument"));

  auto inferred_status_error2 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({r0s32}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_)},
      {r0s32, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error2.ok());
  EXPECT_THAT(inferred_status_error2.status().error_message(),
              HasSubstr("branch operand 2 must match the shape of the only "
                        "parameter of branch computation 2"));

  auto inferred_status_error3 = ShapeInference::InferConditionalShape(
      r0s32,
      {ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_32_}, f32_),
       ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)},
      {vector_32_, vector_32_, vector_32_, vector_64_});
  EXPECT_FALSE(inferred_status_error3.ok());
  EXPECT_THAT(inferred_status_error3.status().error_message(),
              HasSubstr("the result of branch 0 computation and branch 3 "
                        "computation must have the same shape"));

  auto inferred_status_error4 =
      ShapeInference::InferConditionalShape(r0s32, {}, {});
  EXPECT_FALSE(inferred_status_error4.ok());
  EXPECT_THAT(inferred_status_error4.status().error_message(),
              HasSubstr("!branch_computations.empty()"));
}

TEST_F(ShapeInferenceTest, BadSlice) {
  auto arg = ShapeUtil::MakeShape(F32, {4});
  StatusOr<Shape> statusor =
      ShapeInference::InferSliceShape(arg, {0}, {5}, {1});
  ASSERT_FALSE(statusor.ok());

  LOG(INFO) << statusor.status();

  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("less than or equal to dimension size"))
      << statusor.status();
  EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape"))
      << statusor.status();
}

TEST_F(ShapeInferenceTest, BadSort) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> statusor =
      ShapeInference::InferVariadicOpShape(HloOpcode::kSort, {&keys, &values});
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << statusor.status();
}

TEST_F(ShapeInferenceTest, BadSortValuesMismatch) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_good = ShapeUtil::MakeShape(F32, {4});
  auto values_bad = ShapeUtil::MakeShape(F32, {5});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_good, &values_bad});
  EXPECT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("dimensions must match"))
      << statusor.status();
}

TEST_F(ShapeInferenceTest, SortManyValues) {
  auto keys = ShapeUtil::MakeShape(F32, {4});
  auto values_s32 = ShapeUtil::MakeShape(S32, {4});
  auto values_u32 = ShapeUtil::MakeShape(U32, {4});
  StatusOr<Shape> statusor = ShapeInference::InferVariadicOpShape(
      HloOpcode::kSort, {&keys, &values_s32, &values_u32});
  EXPECT_IS_OK(statusor);
  Shape inferred_shape = statusor.ValueOrDie();
  EXPECT_TRUE(ShapeUtil::Compatible(
      inferred_shape,
      ShapeUtil::MakeTupleShape({keys, values_s32, values_u32})));
}

class ScatterGatherShapeInferenceTest : public ShapeInferenceTest {
 protected:
  const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {});
  const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5});
  const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32});
  const Shape s64_4d_tensor_10_9_8_7_1_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1});
  const Shape s64_4d_tensor_10_9_8_7_5_ =
      ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5});
  const Shape s64_4d_tensor_5_10_9_7_6_ =
      ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6});
  const Shape s64_4d_tensor_10_9_5_7_6_ =
      ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6});
  const Shape f32_5d_tensor_50_49_48_47_46_ =
      ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46});
  const Shape tuple_shape_ = ShapeUtil::MakeTupleShape(
      {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_});
  const ProgramShape to_apply_ =
      ShapeUtil::MakeProgramShape({f32_, f32_}, f32_);
};

// Shape inference tests for Gather.

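// Equivalent to tf.gather on the column dimension: each of the 32 indices
// selects a 64-element column, giving an f32[64, 32] result.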
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGather) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0},
                                  /*collapsed_slice_dims=*/{1},
                                  /*start_index_map=*/{1},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{64, 1}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_vector_32_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{1},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/1),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(
      ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TensorFlowGatherNd) {
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{4},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/4),
                              /*slice_sizes=*/{1, 48}));
  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48})))
      << ShapeUtil::HumanString(gather_shape);
}

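// The gather output below is the index batch dimensions {10, 9, 8, 7}
// followed by the slice sizes {30, 29, 28, 27, 26}.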
TEST_F(ScatterGatherShapeInferenceTest, TensorFlowBatchDynamicSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}

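// The two tests below check how dynamic dimensions propagate through gather:
// a dynamic operand dimension that is sliced in full stays dynamic, while a
// collapsed dynamic dimension disappears from the result.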
TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherEntireDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {false, true, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {true, false})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, DynamicGatherCollapsedDimension) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 1}, {true, false, false}),
          ShapeUtil::MakeShape(S64, {}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{0, 1},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{1, 2, 1}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape, ShapeUtil::MakeShape(F32, {2, 1}, {false, false})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, DynamicIndices) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          ShapeUtil::MakeShape(F32, {3, 2, 2}),
          ShapeUtil::MakeShape(S64, {3, 4, 2}, {false, true, false}),
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{2, 3},
              /*collapsed_slice_dims=*/{0},
              /*start_index_map=*/{0, 1},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{1, 2, 2}));
  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {3, 4, 2, 2}, {false, true, false, false})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape gather_shape,
      ShapeInference::InferGatherShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          HloGatherInstruction::MakeGatherDimNumbers(
              /*offset_dims=*/{4, 5, 6, 7, 8},
              /*collapsed_slice_dims=*/{},
              /*start_index_map=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0),
          /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(
      gather_shape,
      ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NoOutputGatherDims) {
  // This is equivalent to a dynamic slice.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3, 4},
                                  /*collapsed_slice_dims=*/{},
                                  /*start_index_map=*/{0, 1, 2, 3, 4},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{30, 29, 28, 27, 26}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, ScalarGatherIndices) {
  // The gather indices "tensor" is a scalar S here that's used to slice out
  // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result.
  TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape,
                          ShapeInference::InferGatherShape(
                              f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
                              HloGatherInstruction::MakeGatherDimNumbers(
                                  /*offset_dims=*/{0, 1, 2, 3},
                                  /*collapsed_slice_dims=*/{0},
                                  /*start_index_map=*/{0},
                                  /*index_vector_dim=*/0),
                              /*slice_sizes=*/{1, 30, 29, 28, 27}));

  EXPECT_TRUE(ShapeUtil::Equal(gather_shape,
                               ShapeUtil::MakeShape(F32, {30, 29, 28, 27})))
      << ShapeUtil::HumanString(gather_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TupleShapedTensorInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      tuple_shape_, s64_vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/1),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for input"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, TupleShapedGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, tuple_shape_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for gather indices"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, FloatingPointGatherIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      s64_vector_32_, vector_32_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{0},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{1},
          /*index_vector_dim=*/0),
      /*slice_sizes=*/{64, 1});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather indices parameter must be an integral tensor"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 8, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must be ascending"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 7},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Output window dimensions in gather op must not repeat"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
|
|
InvalidGatherDimNumbers_WindowIndexOutOfBounds) {
|
|
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
|
|
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
|
|
HloGatherInstruction::MakeGatherDimNumbers(
|
|
/*offset_dims=*/{4, 5, 99, 100, 101},
|
|
/*collapsed_slice_dims=*/{},
|
|
/*start_index_map=*/{0, 1, 2, 3, 4},
|
|
/*index_vector_dim=*/4),
|
|
/*slice_sizes=*/{30, 29, 28, 27, 26});
|
|
ASSERT_FALSE(statusor.ok());
|
|
EXPECT_THAT(statusor.status().error_message(),
|
|
HasSubstr("Offset dimension 2 in gather op is out of bounds"))
|
|
<< statusor.status();
|
|
}
|
|
|
|
TEST_F(ScatterGatherShapeInferenceTest,
|
|
InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) {
|
|
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
|
|
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
|
|
HloGatherInstruction::MakeGatherDimNumbers(
|
|
/*offset_dims=*/{4, 5, 6, 7, 9},
|
|
/*collapsed_slice_dims=*/{},
|
|
/*start_index_map=*/{0, 1, 2, 3, 4},
|
|
/*index_vector_dim=*/4),
|
|
/*slice_sizes=*/{30, 29, 28, 27, 26});
|
|
ASSERT_FALSE(statusor.ok());
|
|
EXPECT_THAT(statusor.status().error_message(),
|
|
HasSubstr("Offset dimension 4 in gather op is out of bounds"))
|
|
<< statusor.status();
|
|
}
|
|
|
|
TEST_F(ScatterGatherShapeInferenceTest,
|
|
InvalidGatherDimNumbers_MismatchingElidedWindowDims) {
|
|
StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
|
|
f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
|
|
HloGatherInstruction::MakeGatherDimNumbers(
|
|
/*offset_dims=*/{4, 5, 6, 7, 8},
|
|
/*collapsed_slice_dims=*/{4},
|
|
/*start_index_map=*/{0, 1, 2, 3, 4},
|
|
/*index_vector_dim=*/4),
|
|
/*slice_sizes=*/{30, 29, 28, 27, 26});
|
|
ASSERT_FALSE(statusor.ok());
|
|
EXPECT_THAT(
|
|
statusor.status().error_message(),
|
|
HasSubstr("All components of the offset index in a gather op must either "
|
|
"be a offset dimension or explicitly collapsed"))
|
|
<< statusor.status();
|
|
}
|
|
|
|
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 19},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid collapsed_slice_dims set in gather op; valid "
                        "range is [0, 5), got: 19"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedWindowToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{0, 1, 2, 3, 3},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Repeated dimensions not allowed in "
                        "collapsed_slice_dims in gather op"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather op has 4 elements in start_index_map and "
                        "the bound of dimension index_vector_dim=4 of "
                        "start_indices is 5. These two numbers must be equal."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 7},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid start_index_map; domain is [0, 5), got: 4->7"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_RepeatedGatherToInputMapping) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 3},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Repeated dimensions are not allowed in start_index_map"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_NonAscendingElidedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{2, 1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{1, 1, 28, 27, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("collapsed_slice_dims in gather op must be sorted"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsTooLarge) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{2},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 1, 300, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Slice size at index 3 in gather op is out of range, "
                        "must be within [0, 48), got 300."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Gather op must have one slice size for every input dimension"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7},
          /*collapsed_slice_dims=*/{1},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4),
      /*slice_sizes=*/{30, 29, 28, 26, 20});
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Gather op can only collapse slice dims with bound 1 or 0, "
                "but bound is 29 for index 1 at position 0."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) {
  StatusOr<Shape> statusor = ShapeInference::InferGatherShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
      HloGatherInstruction::MakeGatherDimNumbers(
          /*offset_dims=*/{4, 5, 6, 7, 8},
          /*collapsed_slice_dims=*/{},
          /*start_index_map=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/32),
      /*slice_sizes=*/{30, 29, 28, 27, 26});

  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Gather index leaf dimension must be within [0, "
                        "rank(start_indices) + 1)"))
      << statusor.status();
}

// Shape inference tests for Scatter.
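//
// On success, InferScatterShape returns the operand's shape: scatter writes
// the updates into the operand, so the inferred result shape always matches
// the operand shape.
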
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {64, 32}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithFullUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {32, 48}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {10, 32}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{0},
                                  /*inserted_window_dims=*/{1},
                                  /*scatter_dims_to_operand_dims=*/{1},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithPartialUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(Shape scatter_shape,
                          ShapeInference::InferScatterShape(
                              matrix_64_48_, s64_vector_32_,
                              ShapeUtil::MakeShape(F32, {32, 8}), to_apply_,
                              HloScatterInstruction::MakeScatterDimNumbers(
                                  /*update_window_dims=*/{1},
                                  /*inserted_window_dims=*/{0},
                                  /*scatter_dims_to_operand_dims=*/{0},
                                  /*index_vector_dim=*/1)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

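// Each V2 variant above mirrors the corresponding test with the roles of the
// two update dimensions swapped: the update window dimension and the scatter
// (index) dimension trade places, and the window is applied to the other
// operand dimension.
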
TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {65, 32}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterWithUpdatesBiggerThanInputV2) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {32, 49}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {64, 31}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterWithUpdatesNotMatchingIndicesV2) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_vector_32_, ShapeUtil::MakeShape(F32, {31, 48}),
      to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{1},
          /*inserted_window_dims=*/{0},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithFullUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 64}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdates) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 10}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithPartialUpdatesV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 12}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4},
              /*inserted_window_dims=*/{1},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, matrix_64_48_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, TfScatterNdWithUpdatesBiggerThanInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 65}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Bounds of the window dimensions of updates must not exceed "
                "the bounds of the corresponding dimensions of operand."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       TfScatterNdWithUpdatesNotMatchingIndices) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      matrix_64_48_, s64_4d_tensor_10_9_8_7_1_,
      ShapeUtil::MakeShape(F32, {9, 9, 8, 7, 64}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Bounds of the scatter dimensions of updates must be same as the "
          "bounds of the corresponding dimensions of scatter indices."))
      << statusor.status();
}

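// With one update window dimension per operand dimension and no
// inserted_window_dims, scatter degenerates into a batched dynamic update
// slice.
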
TEST_F(ScatterGatherShapeInferenceTest, TfBatchDynamicUpdateSlice) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
          ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/4)));
  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDim) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/2)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NonDefaultScatterIndicesLeafDimV2) {
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_,
          ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}),
          to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{4, 5, 6, 7, 8},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, NoUpdateScatterDims) {
  // This is equivalent to a dynamic update slice.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_vector_5_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3, 4},
              /*inserted_window_dims=*/{},
              /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

TEST_F(ScatterGatherShapeInferenceTest, ScalarScatterIndices) {
  // The scalar indices "tensor" is a scalar S here that's used to update a
  // [30,29,28,27] shaped tensor within the operand at position S.
  TF_ASSERT_OK_AND_ASSIGN(
      Shape scatter_shape,
      ShapeInference::InferScatterShape(
          f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
          ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
          HloScatterInstruction::MakeScatterDimNumbers(
              /*update_window_dims=*/{0, 1, 2, 3},
              /*inserted_window_dims=*/{0},
              /*scatter_dims_to_operand_dims=*/{0},
              /*index_vector_dim=*/0)));

  EXPECT_TRUE(ShapeUtil::Equal(scatter_shape, f32_5d_tensor_50_49_48_47_46_))
      << ShapeUtil::HumanString(scatter_shape);
}

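// Scatter error cases: tuple-shaped or floating-point arguments, out-of-range
// dimension numbers, and malformed update computations must all produce
// descriptive error statuses.
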
TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedTensorInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      tuple_shape_, s64_vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/1));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for operand"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       ScatterWithTupleShapedScatterIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, tuple_shape_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for scatter indices"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, ScatterWithTupleShapedUpdatesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, s64_vector_32_, tuple_shape_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Expected array argument for updates"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, FloatingPointScatterIndicesInput) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      s64_vector_32_, vector_32_, s64_vector_32_, to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0},
          /*inserted_window_dims=*/{1},
          /*scatter_dims_to_operand_dims=*/{1},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter indices parameter must be an integral tensor"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, OutOfBoundsScatterIndicesLeafDim) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/10));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Scatter index leaf dimension must be within [0, "
                        "rank(scatter_indices) + 1)"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdates) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 50}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Updates tensor must be of rank 7; got 8."))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest, InvalidUpdateComputation) {
  const ProgramShape invalid_update_computation =
      ShapeUtil::MakeProgramShape({f32_}, f32_);
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}),
      invalid_update_computation,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Reduction function must take 2 parameters, but takes 1"))
      << statusor.status();
}

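// The remaining tests exercise validation of the individual fields of the
// scatter dimension numbers (update_window_dims, inserted_window_dims,
// scatter_dims_to_operand_dims).
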
TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 8, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must be sorted"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 7},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("update_window_dims in scatter op must not repeat"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsUpdateWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6, 7, 9},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid update_window_dims set in scatter op; valid "
                        "range is [0, 9)"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_NonAscendingInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{2, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must be sorted"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 1},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("inserted_window_dims in scatter op must not repeat"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsInsertedWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 5},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 4},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid inserted_window_dims set in scatter op; valid "
                        "range is [0, 5)"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_MismatchingScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr("Scatter op has 4 elements in scatter_dims_to_operand_dims and "
                "the bound of dimension index_vector_dim=4 of scatter_indices "
                "is 5. These two numbers must be equal"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_OutOfBoundsScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 3, 10},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(statusor.status().error_message(),
              HasSubstr("Invalid scatter_dims_to_operand_dims mapping; domain "
                        "is [0, 5), got: 4->10"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_RepeatedValuesInScatterDimsToOperandDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_,
      ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{4, 5, 6},
          /*inserted_window_dims=*/{1, 2},
          /*scatter_dims_to_operand_dims=*/{0, 1, 2, 2, 3},
          /*index_vector_dim=*/4));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Repeated dimensions not allowed in scatter_dims_to_operand_dims"))
      << statusor.status();
}

TEST_F(ScatterGatherShapeInferenceTest,
       InvalidScatterDimNumbers_InsufficientWindowDims) {
  StatusOr<Shape> statusor = ShapeInference::InferScatterShape(
      f32_5d_tensor_50_49_48_47_46_, s64_scalar_,
      ShapeUtil::MakeShape(F32, {30, 29, 28, 27}), to_apply_,
      HloScatterInstruction::MakeScatterDimNumbers(
          /*update_window_dims=*/{0, 1, 2, 3},
          /*inserted_window_dims=*/{},
          /*scatter_dims_to_operand_dims=*/{0},
          /*index_vector_dim=*/0));
  ASSERT_FALSE(statusor.ok());
  EXPECT_THAT(
      statusor.status().error_message(),
      HasSubstr(
          "Scatter op has window of size 4; doesn't match operand of rank 5."))
      << statusor.status();
}

}  // namespace
}  // namespace xla