Refactoring and test coverage improvements
PiperOrigin-RevId: 338714669 Change-Id: Iffb1c652be6a3c44b9ecdafc89696231fc2ae904
This commit is contained in:
parent
2c787e0028
commit
66b124ef0e
@ -76,8 +76,8 @@ class DatasetBase;
|
||||
class SerializationContext;
|
||||
|
||||
inline bool IsTFDataFunction(const FunctionDef& func) {
|
||||
return (func.attr().contains(data::kTFDataFunction) &&
|
||||
func.attr().at(data::kTFDataFunction).b());
|
||||
auto iter = func.attr().find(data::kTFDataFunction);
|
||||
return (iter != func.attr().end() && iter->second.b());
|
||||
}
|
||||
|
||||
// Interface for reading values from a key-value store.
|
||||
|
@ -615,9 +615,13 @@ Status MetaOptimizer::RunOptimizer(
|
||||
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
|
||||
FunctionDefLibrary& fdef_lib) {
|
||||
// Collect functions that need the attribute in this set.
|
||||
absl::flat_hash_set<string> tf_data_functions;
|
||||
std::function<void(const string&)> collect_tf_data_functions_dfs =
|
||||
[&](const string& func_name) -> void {
|
||||
absl::flat_hash_set<std::string> tf_data_functions;
|
||||
std::function<void(const std::string&)> collect_tf_data_functions_dfs =
|
||||
[&](const std::string& func_name) -> void {
|
||||
const FunctionDef* func_def = flib.Find(func_name);
|
||||
// Skip functions that are not reachable from the optimized graph.
|
||||
if (func_def == nullptr) return;
|
||||
|
||||
// Return if we already found and added this function.
|
||||
if (tf_data_functions.contains(func_name)) return;
|
||||
|
||||
@ -625,10 +629,6 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
|
||||
// a tf.data function, so add it to the set.
|
||||
tf_data_functions.insert(func_name);
|
||||
|
||||
const FunctionDef* func_def = flib.Find(func_name);
|
||||
// Skip functions that are not reachable from the optimized graph.
|
||||
if (func_def == nullptr) return;
|
||||
|
||||
// Proceed with DFS for functions called from current function.
|
||||
for (const NodeDef& node : func_def->node_def()) {
|
||||
if (flib.Contains(node.op())) {
|
||||
@ -651,7 +651,7 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
|
||||
};
|
||||
// Perform DFS for all tf.data functions in `fdef_lib`.
|
||||
for (const auto& func_def : fdef_lib.function()) {
|
||||
const string& func_name = func_def.signature().name();
|
||||
const std::string& func_name = func_def.signature().name();
|
||||
if (data::IsTFDataFunction(func_def))
|
||||
collect_tf_data_functions_dfs(func_name);
|
||||
}
|
||||
@ -659,7 +659,7 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
|
||||
// because `FunctionLibraryDefinition` does not seem to provide mutable access
|
||||
// to a `FunctionDef`.
|
||||
for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
|
||||
const string& func_name = func_def.signature().name();
|
||||
const std::string& func_name = func_def.signature().name();
|
||||
if (tf_data_functions.contains(func_name) &&
|
||||
!data::IsTFDataFunction(func_def)) {
|
||||
VLOG(2) << "Marking " << func_name << " as tf.data function";
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
@ -1020,14 +1022,19 @@ TEST_F(MetaOptimizerTest, CompressConstants) {
|
||||
// Tests for checking expected behavior when skipping tf.data functions in
|
||||
// meta optimizer.
|
||||
|
||||
// Custom optimizer which counts its calls.
|
||||
// Custom optimizer which counts the number of calls of its method `Optimize`
|
||||
// across all class instances.
|
||||
class TfDataTestOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
static void InitCount() { cnt_ = 0; }
|
||||
static int GetCount() { return cnt_; }
|
||||
static void InitCount() { count_ = 0; }
|
||||
static int GetCount() { return count_; }
|
||||
|
||||
TfDataTestOptimizer() {}
|
||||
string name() const override { return "tf_data_test_optimizer"; }
|
||||
TfDataTestOptimizer() = default;
|
||||
~TfDataTestOptimizer() override = default;
|
||||
TfDataTestOptimizer(const TfDataTestOptimizer&) = delete;
|
||||
TfDataTestOptimizer& operator=(const TfDataTestOptimizer& other) = delete;
|
||||
|
||||
std::string name() const override { return "tf_data_test_optimizer"; }
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
|
||||
Status Init(
|
||||
@ -1037,7 +1044,7 @@ class TfDataTestOptimizer : public CustomGraphOptimizer {
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override {
|
||||
++cnt_;
|
||||
++count_;
|
||||
*optimized_graph = item.graph;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1046,69 +1053,128 @@ class TfDataTestOptimizer : public CustomGraphOptimizer {
|
||||
const GraphDef& optimized_graph, double result) override {}
|
||||
|
||||
private:
|
||||
static int cnt_;
|
||||
static std::atomic<int> count_;
|
||||
};
|
||||
|
||||
int TfDataTestOptimizer::cnt_;
|
||||
std::atomic<int> TfDataTestOptimizer::count_;
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);
|
||||
|
||||
// Test fixture for parametrized testing.
|
||||
class TfDataTestFixture
|
||||
: public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
is_my_mul_tf_data_ = std::get<0>(GetParam());
|
||||
is_my_square_tf_data_ = std::get<1>(GetParam());
|
||||
}
|
||||
void RunTest();
|
||||
|
||||
private:
|
||||
// controls which of the functions is flagged as tf.data function
|
||||
bool is_my_mul_tf_data_ = false;
|
||||
bool is_my_square_tf_data_ = false;
|
||||
// Type for specifying how the inner function is nested inside the outer
|
||||
// function.
|
||||
enum class FuncNestingType {
|
||||
CallFromNode = 0,
|
||||
CallFromAttr = 1,
|
||||
CallFromList = 2
|
||||
};
|
||||
|
||||
TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); }
|
||||
// Test fixture for parametrized testing.
|
||||
class TfDataTestFixture
|
||||
: public ::testing::TestWithParam<std::tuple<bool, bool, FuncNestingType>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
is_inner_func_tf_data_ = std::get<0>(GetParam());
|
||||
is_outer_func_tf_data_ = std::get<1>(GetParam());
|
||||
func_nesting_type_ = std::get<2>(GetParam());
|
||||
}
|
||||
// Controls which of the functions is flagged as tf.data function.
|
||||
bool is_inner_func_tf_data_ = false;
|
||||
bool is_outer_func_tf_data_ = false;
|
||||
// Controls how the inner function is nested inside the outer function.
|
||||
FuncNestingType func_nesting_type_ = FuncNestingType::CallFromNode;
|
||||
};
|
||||
|
||||
// Core test function.
|
||||
void TfDataTestFixture::RunTest() {
|
||||
// Helper functions for setting up the call of `inner_func` inside of
|
||||
// `outer_func`.
|
||||
|
||||
void SetUpCallFromNode(FunctionDef& outer_func) {
|
||||
// Call `inner_func` from a node in `outer_func`.
|
||||
outer_func = FunctionDefHelper::Create(
|
||||
"outer_func", {"x:float"}, {"z:float"}, {},
|
||||
/*node_def=*/
|
||||
{{{"inner_func"}, "inner_func", {"x", "x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "inner_func:z:0"}});
|
||||
}
|
||||
|
||||
void SetUpCallFromAttr(FunctionDef& outer_func) {
|
||||
// Call `inner_func` from an attribute in a node in `outer_func`.
|
||||
outer_func = FunctionDefHelper::Create(
|
||||
"outer_func", {"x:float"}, {"z:float"}, {},
|
||||
/*node_def=*/
|
||||
{{{"identity"},
|
||||
"Identity",
|
||||
{"x"},
|
||||
{{"T", DT_FLOAT},
|
||||
{"f", FunctionDefHelper::FunctionRef("inner_func", {})}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "x"}});
|
||||
}
|
||||
|
||||
void SetUpCallFromList(FunctionDef& outer_func) {
|
||||
// Call `inner_func` from a list attribute in a node in `outer_func`.
|
||||
outer_func = FunctionDefHelper::Create(
|
||||
"outer_func", {"x:float"}, {"z:float"}, {},
|
||||
/*node_def=*/
|
||||
{{{"identity"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "x"}});
|
||||
|
||||
// Add a list containing `inner_func` to the `identity` node.
|
||||
// `list_value` will be deallocated automatically since it is passed as
|
||||
// allocated list below.
|
||||
AttrValue_ListValue* list_value =
|
||||
(*outer_func.mutable_node_def(0)->mutable_attr())["list"].mutable_list();
|
||||
NameAttrList* entry = list_value->add_func();
|
||||
entry->set_name("inner_func");
|
||||
}
|
||||
|
||||
TEST_P(TfDataTestFixture, TfDataTests) {
|
||||
using test::function::NDef;
|
||||
|
||||
// Define function library:
|
||||
//
|
||||
// MyMul(x, y) = x * y
|
||||
// MySquare(x) = MyMul(x, x)
|
||||
// Define function library with `outer_func` and `inner_func`.
|
||||
|
||||
FunctionDef mul_func = FunctionDefHelper::Create(
|
||||
"MyMul", {"x:float", "y:float"}, {"z:float"}, {},
|
||||
FunctionDef inner_func = FunctionDefHelper::Create(
|
||||
"inner_func", {"x:float", "y:float"}, {"z:float"}, {},
|
||||
/*node_def=*/
|
||||
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
(*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_);
|
||||
(*inner_func.mutable_attr())[data::kTFDataFunction].set_b(
|
||||
is_inner_func_tf_data_);
|
||||
|
||||
FunctionDef square_func = FunctionDefHelper::Create(
|
||||
"MySquare", {"x:float"}, {"z:float"}, {},
|
||||
{{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "my_mul:z:0"}});
|
||||
(*square_func.mutable_attr())[data::kTFDataFunction].set_b(
|
||||
is_my_square_tf_data_);
|
||||
FunctionDef outer_func;
|
||||
switch (func_nesting_type_) {
|
||||
case FuncNestingType::CallFromNode:
|
||||
SetUpCallFromNode(outer_func);
|
||||
break;
|
||||
case FuncNestingType::CallFromAttr:
|
||||
SetUpCallFromAttr(outer_func);
|
||||
break;
|
||||
case FuncNestingType::CallFromList:
|
||||
SetUpCallFromList(outer_func);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
(*outer_func.mutable_attr())[data::kTFDataFunction].set_b(
|
||||
is_outer_func_tf_data_);
|
||||
|
||||
// Tensorflow graph:
|
||||
//
|
||||
// a = tf.Placeholder(tf.float);
|
||||
// square = MySquare(a); // a^2
|
||||
// result = outer_func(a);
|
||||
GrapplerItem item;
|
||||
item.id = "tf_graph";
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
// Calls into function library
|
||||
NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
|
||||
NDef("outer_func_node", "outer_func", {"a"}, {{"T", DT_FLOAT}}, kDevice),
|
||||
// Forward outputs
|
||||
NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
NDef("out_s", "Identity", {"outer_func_node:0"}, {{"T", DT_FLOAT}},
|
||||
kDevice)},
|
||||
/*funcs=*/
|
||||
{mul_func, square_func});
|
||||
{inner_func, outer_func});
|
||||
|
||||
// Use only custom optimizer which counts its calls.
|
||||
TfDataTestOptimizer::InitCount();
|
||||
@ -1125,35 +1191,67 @@ void TfDataTestFixture::RunTest() {
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
// We expect one graph optimization + one optimization for each non-tf.data
|
||||
// function. Note that if `MySquare` is flagged as a tf.data function, then
|
||||
// `MyMul` is implicitly also considered a tf.data function because it is
|
||||
// called from `MySquare`.
|
||||
// function. Note that if `outer_func` is flagged as a tf.data function, then
|
||||
// `inner_func` is implicitly also considered a tf.data function because it is
|
||||
// called from `outer_func`.
|
||||
int expected_count = 3;
|
||||
if (is_my_square_tf_data_)
|
||||
expected_count -= 2;
|
||||
else if (is_my_mul_tf_data_)
|
||||
expected_count -= 1;
|
||||
if (is_outer_func_tf_data_)
|
||||
expected_count = 1;
|
||||
else if (is_inner_func_tf_data_)
|
||||
expected_count = 2;
|
||||
EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);
|
||||
|
||||
// We expect that the tf.data-attribute has been propagated from `MySquare`
|
||||
// to its callee `MyMul` if the value is `true`. Otherwise, the attribute
|
||||
// We expect that the tf.data-attribute has been propagated from `outer_func`
|
||||
// to its callee `inner_func` if the value is `true`. Otherwise, the attribute
|
||||
// values should be unchanged.
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
|
||||
const FunctionDef* square_func_after_opt = flib.Find("MySquare");
|
||||
const FunctionDef* mul_func_after_opt = flib.Find("MyMul");
|
||||
const FunctionDef* outer_func_after_opt = flib.Find("outer_func");
|
||||
const FunctionDef* inner_func_after_opt = flib.Find("inner_func");
|
||||
|
||||
EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt),
|
||||
is_my_square_tf_data_);
|
||||
if (is_my_square_tf_data_ || is_my_mul_tf_data_) {
|
||||
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true);
|
||||
EXPECT_EQ(data::IsTFDataFunction(*outer_func_after_opt),
|
||||
is_outer_func_tf_data_);
|
||||
if (is_outer_func_tf_data_ || is_inner_func_tf_data_) {
|
||||
EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), true);
|
||||
} else {
|
||||
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false);
|
||||
EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), false);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture,
|
||||
::testing::Combine(::testing::Bool(),
|
||||
::testing::Bool()));
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
MetaOptimizerTest, TfDataTestFixture,
|
||||
::testing::Combine(::testing::Bool(), ::testing::Bool(),
|
||||
::testing::Values(FuncNestingType::CallFromNode,
|
||||
FuncNestingType::CallFromAttr,
|
||||
FuncNestingType::CallFromList)),
|
||||
[](const ::testing::TestParamInfo<TfDataTestFixture::ParamType>& info) {
|
||||
bool is_inner_func_tf_data = std::get<0>(info.param);
|
||||
bool is_outer_func_tf_data = std::get<1>(info.param);
|
||||
FuncNestingType func_nesting_type = std::get<2>(info.param);
|
||||
|
||||
std::string test_name;
|
||||
if (is_inner_func_tf_data && is_outer_func_tf_data)
|
||||
test_name = "both_funcs_tf_data";
|
||||
else if (is_inner_func_tf_data)
|
||||
test_name = "inner_func_tf_data";
|
||||
else if (is_outer_func_tf_data)
|
||||
test_name = "outer_func_tf_data";
|
||||
else
|
||||
test_name = "no_func_tf_data";
|
||||
switch (func_nesting_type) {
|
||||
case FuncNestingType::CallFromNode:
|
||||
test_name += "_call_from_node";
|
||||
break;
|
||||
case FuncNestingType::CallFromAttr:
|
||||
test_name += "_call_from_attribute";
|
||||
break;
|
||||
case FuncNestingType::CallFromList:
|
||||
test_name += "_call_from_list";
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return test_name;
|
||||
});
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
|
Loading…
Reference in New Issue
Block a user