Add tests for dynamic parameter binding's serialization and deserialization behaviors.

PiperOrigin-RevId: 230418554
This commit is contained in:
Yunxing Dai 2019-01-22 15:06:11 -08:00 committed by TensorFlower Gardener
parent f7d933a71f
commit 29970780c7
3 changed files with 66 additions and 38 deletions

View File

@ -29,7 +29,8 @@ Status DynamicParameterBinding::Bind(
}
absl::optional<DynamicParameterBinding::DynamicParameter>
DynamicParameterBinding::GetBinding(const DynamicDimension& dynamic_dimension) {
DynamicParameterBinding::GetBinding(
const DynamicDimension& dynamic_dimension) const {
auto param_iter = bindings_.find(dynamic_dimension);
if (param_iter == bindings_.end()) {
return absl::nullopt;

View File

@ -89,7 +89,7 @@ class DynamicParameterBinding {
//
// Returns nullopt if the binding is not set.
absl::optional<DynamicParameter> GetBinding(
const DynamicDimension& dynamic_dimension);
const DynamicDimension& dynamic_dimension) const;
using BindingFn =
std::function<Status(const DynamicParameter& dynamic_parameter,

View File

@ -33,7 +33,15 @@ limitations under the License.
namespace xla {
namespace {
class DynamicParameterBindingTest : public HloTestBase {};
class DynamicParameterBindingTest : public HloTestBase {
protected:
// Serialize and then deserialize a binding.
void SerializeAndDeserialize(DynamicParameterBinding* binding) {
DynamicParameterBindingProto proto = binding->ToProto();
TF_ASSERT_OK_AND_ASSIGN(*binding,
DynamicParameterBinding::CreateFromProto(proto));
}
};
TEST_F(DynamicParameterBindingTest, SimpleBinding) {
// 'b' is a dynamic shape; 'a' represents the real size of b's first
@ -56,15 +64,20 @@ ENTRY main {
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}},
DynamicParameterBinding::DynamicDimension{1, {}, 0}));
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
/*parameter_index=*/{},
/*dimension=*/0});
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({}));
TF_EXPECT_OK(binding.Verify(*module));
auto test = [&](const DynamicParameterBinding& binding) {
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
/*parameter_index=*/{},
/*dimension=*/0});
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({}));
TF_EXPECT_OK(binding.Verify(*module));
};
test(binding);
SerializeAndDeserialize(&binding);
test(binding);
}
TEST_F(DynamicParameterBindingTest, TupleBinding) {
@ -89,16 +102,21 @@ ENTRY main {
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
DynamicParameterBinding::DynamicDimension{0, {1}, 0}));
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
auto test = [&](const DynamicParameterBinding& binding) {
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
TF_EXPECT_OK(binding.Verify(*module));
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
TF_EXPECT_OK(binding.Verify(*module));
};
test(binding);
SerializeAndDeserialize(&binding);
test(binding);
}
TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) {
@ -127,26 +145,35 @@ ENTRY main {
binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
DynamicParameterBinding::DynamicDimension{0, {1}, 1}));
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
auto test = [&](const DynamicParameterBinding& binding) {
absl::optional<DynamicParameterBinding::DynamicParameter> param =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
EXPECT_TRUE(param);
EXPECT_EQ(param->parameter_num, 0);
EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
EXPECT_TRUE(param2);
EXPECT_EQ(param2->parameter_num, 0);
EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
TF_EXPECT_OK(binding.Verify(*module));
binding.GetBinding(
DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
/*parameter_index=*/{1},
/*dimension=*/0});
EXPECT_TRUE(param2);
EXPECT_EQ(param2->parameter_num, 0);
EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
TF_EXPECT_OK(binding.Verify(*module));
};
test(binding);
SerializeAndDeserialize(&binding);
// Test the binding again after deserialization.
test(binding);
}
} // namespace