Add tests for dynamic parameter binding's serialization and deserialization behaviors.
PiperOrigin-RevId: 230418554
This commit is contained in:
parent
f7d933a71f
commit
29970780c7
@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind(
|
||||
}
|
||||
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter>
|
||||
DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) {
|
||||
DynamicParameterBinding::GetBinding(
|
||||
const DynamicDimension& dynamic_dimension) const {
|
||||
auto param_iter = bindings_.find(dynamic_dimension);
|
||||
if (param_iter == bindings_.end()) {
|
||||
return absl::nullopt;
|
||||
|
@ -89,7 +89,7 @@ class DynamicParameterBinding {
|
||||
//
|
||||
// Returns nullopt if the binding is not set.
|
||||
absl::optional<DynamicParameter> GetBinding(
|
||||
const DynamicDimension& dynamic_dimension);
|
||||
const DynamicDimension& dynamic_dimension) const;
|
||||
|
||||
using BindingFn =
|
||||
std::function<Status(const DynamicParameter& dynamic_parameter,
|
||||
|
@ -33,7 +33,15 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
class DynamicParameterBindingTest : public HloTestBase {};
|
||||
class DynamicParameterBindingTest : public HloTestBase {
|
||||
protected:
|
||||
// Serialize and then deserialize a binding.
|
||||
void SerializeAndDeserialize(DynamicParameterBinding* binding) {
|
||||
DynamicParameterBindingProto proto = binding->ToProto();
|
||||
TF_ASSERT_OK_AND_ASSIGN(*binding,
|
||||
DynamicParameterBinding::CreateFromProto(proto));
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(DynamicParameterBindingTest, SimpleBinding) {
|
||||
// 'b' is a dynamic shape; 'a' represents the real size of b's first
|
||||
@ -56,15 +64,20 @@ ENTRY main {
|
||||
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}},
|
||||
DynamicParameterBinding::DynamicDimension{1, {}, 0}));
|
||||
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
|
||||
/*parameter_index=*/{},
|
||||
/*dimension=*/0});
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({}));
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
auto test = [&](const DynamicParameterBinding& binding) {
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
|
||||
/*parameter_index=*/{},
|
||||
/*dimension=*/0});
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({}));
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
};
|
||||
test(binding);
|
||||
SerializeAndDeserialize(&binding);
|
||||
test(binding);
|
||||
}
|
||||
|
||||
TEST_F(DynamicParameterBindingTest, TupleBinding) {
|
||||
@ -89,16 +102,21 @@ ENTRY main {
|
||||
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
|
||||
DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
|
||||
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
auto test = [&](const DynamicParameterBinding& binding) {
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
};
|
||||
test(binding);
|
||||
SerializeAndDeserialize(&binding);
|
||||
test(binding);
|
||||
}
|
||||
|
||||
TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) {
|
||||
@ -127,26 +145,35 @@ ENTRY main {
|
||||
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
|
||||
DynamicParameterBinding::DynamicDimension{0, {1}, 1}));
|
||||
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
auto test = [&](const DynamicParameterBinding& binding) {
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
|
||||
EXPECT_TRUE(param);
|
||||
EXPECT_EQ(param->parameter_num, 0);
|
||||
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
|
||||
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
EXPECT_TRUE(param2);
|
||||
EXPECT_EQ(param2->parameter_num, 0);
|
||||
EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
|
||||
absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
|
||||
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
binding.GetBinding(
|
||||
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
|
||||
/*parameter_index=*/{1},
|
||||
/*dimension=*/0});
|
||||
EXPECT_TRUE(param2);
|
||||
EXPECT_EQ(param2->parameter_num, 0);
|
||||
EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
|
||||
TF_EXPECT_OK(binding.Verify(*module));
|
||||
};
|
||||
|
||||
test(binding);
|
||||
|
||||
SerializeAndDeserialize(&binding);
|
||||
|
||||
// Test the binding again after deserialization.
|
||||
test(binding);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
Loading…
x
Reference in New Issue
Block a user