Add a test for api_implements exist but preferred_device not exist.
PiperOrigin-RevId: 315815796 Change-Id: I67c1c1f3b9b409a5ade455a67debeb1d2274a27b
This commit is contained in:
parent
d0d096f5fb
commit
cc41b2735c
@ -220,6 +220,62 @@ TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ImplementationSelectorTest, NoSwapWithImplementsOnly) {
|
||||||
|
using test::function::NDef;
|
||||||
|
ImplementationSelector optimizer;
|
||||||
|
GraphDef output;
|
||||||
|
GrapplerItem item;
|
||||||
|
// DeviceIndex op based implementation selector.
|
||||||
|
AttrValue device_names;
|
||||||
|
device_names.mutable_list()->add_s("CPU");
|
||||||
|
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||||
|
device_names.mutable_list()->add_s("GPU");
|
||||||
|
|
||||||
|
// Api_implements exists, api_preferred_device does not, no swap.
|
||||||
|
auto cpu_def = test::function::XTimesTwo();
|
||||||
|
auto* func_attr = cpu_def.mutable_attr();
|
||||||
|
(*func_attr)["api_implements"].set_s("times_two");
|
||||||
|
|
||||||
|
auto gpu_def = test::function::XAddX();
|
||||||
|
auto* func2_attr = gpu_def.mutable_attr();
|
||||||
|
(*func2_attr)["api_implements"].set_s("times_two");
|
||||||
|
|
||||||
|
item.graph = test::function::GDef(
|
||||||
|
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||||
|
CpuDevice),
|
||||||
|
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||||
|
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||||
|
GpuDevice),
|
||||||
|
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
|
||||||
|
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||||
|
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||||
|
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
|
||||||
|
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
|
||||||
|
// FunctionLib
|
||||||
|
{cpu_def, gpu_def});
|
||||||
|
|
||||||
|
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||||
|
for (const NodeDef& node : output.node()) {
|
||||||
|
if (node.name() == "x") {
|
||||||
|
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||||
|
EXPECT_EQ("Const", node.op());
|
||||||
|
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||||
|
}
|
||||||
|
if (node.name() == "y") {
|
||||||
|
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||||
|
EXPECT_EQ("Const", node.op());
|
||||||
|
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||||
|
}
|
||||||
|
if (node.name() == "y1") {
|
||||||
|
// api_implements only, no preferred device, no swap.
|
||||||
|
EXPECT_EQ("XTimesTwo", node.op());
|
||||||
|
} else if (node.name() == "y2") {
|
||||||
|
// Make sure the implementation is not changed.
|
||||||
|
EXPECT_EQ("XTimesTwo", node.op());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(ImplementationSelectorTest, SwapImplementation) {
|
TEST_F(ImplementationSelectorTest, SwapImplementation) {
|
||||||
using test::function::NDef;
|
using test::function::NDef;
|
||||||
auto cpu_def = test::function::XTimesTwo();
|
auto cpu_def = test::function::XTimesTwo();
|
||||||
|
Loading…
Reference in New Issue
Block a user