Handle the case where api_implements exists but api_prefered_device does not exist.
PiperOrigin-RevId: 315743103 Change-Id: I085c871f7df7713daacebd595a1aa730b5dcda58
This commit is contained in:
parent
e66dabe79d
commit
65a2d4656a
@ -59,13 +59,6 @@ Status FunctionApiInfo::Init(const FunctionDef& function_def) {
|
||||
"Function '", function_def.signature().name(),
|
||||
"' has a preferred device, but does not implement an interface");
|
||||
}
|
||||
// Handles the case that api_implements exists but prefered_device does not
|
||||
// exist. Currently this is for tf lite/mlir, which depends on api_implements.
|
||||
if (!interface_name_.empty() && preferred_device_.empty()) {
|
||||
VLOG(1) << "A function has api_implements: " << interface_name_ << ", but "
|
||||
<< "api_preferred_device";
|
||||
interface_name_.clear();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -138,9 +138,11 @@ TEST(FunctionApiInfoTest, ParseTags) {
|
||||
|
||||
EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
|
||||
EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
|
||||
EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
|
||||
|
||||
EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
|
||||
EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
|
||||
EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
|
||||
|
||||
EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
|
||||
EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
|
||||
@ -183,19 +185,6 @@ TEST(FunctionApiInfoTest, MismatchedArguments) {
|
||||
EXPECT_FALSE(ret.ok());
|
||||
}
|
||||
|
||||
TEST(FunctionApiInfoTest, ImplementsWithoutDevice) {
|
||||
FunctionDefLibrary func_lib;
|
||||
const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
|
||||
const std::vector<ArgSpec> output_args{{"out", "float32"}};
|
||||
PopulateFunction("DoThings", "DoThings", "", func_args, output_args, "", "",
|
||||
func_lib.add_function());
|
||||
FunctionLibraryApiInfo lib_api_info;
|
||||
const Status ret = lib_api_info.Init(func_lib);
|
||||
EXPECT_TRUE(ret.ok());
|
||||
EXPECT_TRUE(lib_api_info.empty());
|
||||
EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -220,62 +220,6 @@ TEST_F(ImplementationSelectorTest, TwoTypesOfSwapImplementation) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, NoSwapWithImplementsOnly) {
|
||||
using test::function::NDef;
|
||||
ImplementationSelector optimizer;
|
||||
GraphDef output;
|
||||
GrapplerItem item;
|
||||
// DeviceIndex op based implementation selector.
|
||||
AttrValue device_names;
|
||||
device_names.mutable_list()->add_s("CPU");
|
||||
device_names.mutable_list()->add_s("TPU_REPLICATED_CORE");
|
||||
device_names.mutable_list()->add_s("GPU");
|
||||
|
||||
// Api_implements exists, api_preferred_device does not, no swap.
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
auto* func_attr = cpu_def.mutable_attr();
|
||||
(*func_attr)["api_implements"].set_s("times_two");
|
||||
|
||||
auto gpu_def = test::function::XAddX();
|
||||
auto* func2_attr = gpu_def.mutable_attr();
|
||||
(*func2_attr)["api_implements"].set_s("times_two");
|
||||
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("x", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
CpuDevice),
|
||||
NDef("case", "Case", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y", "DeviceIndex", {}, {{"device_names", device_names}},
|
||||
GpuDevice),
|
||||
NDef("case_y", "Case", {"y"}, {{"T", DT_FLOAT}}, TpuDevice),
|
||||
NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
|
||||
NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
|
||||
NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
|
||||
// FunctionLib
|
||||
{cpu_def, gpu_def});
|
||||
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
|
||||
for (const NodeDef& node : output.node()) {
|
||||
if (node.name() == "x") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of GPU index 1.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(2, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y") {
|
||||
// Rewrite DeviceIndex op to a Const op with value of CPU index 0.
|
||||
EXPECT_EQ("Const", node.op());
|
||||
EXPECT_EQ(1, node.attr().at("value").tensor().int_val(0));
|
||||
}
|
||||
if (node.name() == "y1") {
|
||||
// api_implements only, no preferred device, no swap.
|
||||
EXPECT_EQ("XTimesTwo", node.op());
|
||||
} else if (node.name() == "y2") {
|
||||
// Make sure the implementation is not changed.
|
||||
EXPECT_EQ("XTimesTwo", node.op());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(ImplementationSelectorTest, SwapImplementation) {
|
||||
using test::function::NDef;
|
||||
auto cpu_def = test::function::XTimesTwo();
|
||||
|
Loading…
Reference in New Issue
Block a user