Refactoring and test coverage improvements

PiperOrigin-RevId: 338714669
Change-Id: Iffb1c652be6a3c44b9ecdafc89696231fc2ae904
This commit is contained in:
Michael Gester 2020-10-23 11:42:15 -07:00 committed by TensorFlower Gardener
parent 2c787e0028
commit 66b124ef0e
3 changed files with 171 additions and 73 deletions

View File

@ -76,8 +76,8 @@ class DatasetBase;
class SerializationContext;
inline bool IsTFDataFunction(const FunctionDef& func) {
return (func.attr().contains(data::kTFDataFunction) &&
func.attr().at(data::kTFDataFunction).b());
auto iter = func.attr().find(data::kTFDataFunction);
return (iter != func.attr().end() && iter->second.b());
}
// Interface for reading values from a key-value store.

View File

@ -615,9 +615,13 @@ Status MetaOptimizer::RunOptimizer(
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
FunctionDefLibrary& fdef_lib) {
// Collect functions that need the attribute in this set.
absl::flat_hash_set<string> tf_data_functions;
std::function<void(const string&)> collect_tf_data_functions_dfs =
[&](const string& func_name) -> void {
absl::flat_hash_set<std::string> tf_data_functions;
std::function<void(const std::string&)> collect_tf_data_functions_dfs =
[&](const std::string& func_name) -> void {
const FunctionDef* func_def = flib.Find(func_name);
// Skip functions that are not reachable from the optimized graph.
if (func_def == nullptr) return;
// Return if we already found and added this function.
if (tf_data_functions.contains(func_name)) return;
@ -625,10 +629,6 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
// a tf.data function, so add it to the set.
tf_data_functions.insert(func_name);
const FunctionDef* func_def = flib.Find(func_name);
// Skip functions that are not reachable from the optimized graph.
if (func_def == nullptr) return;
// Proceed with DFS for functions called from current function.
for (const NodeDef& node : func_def->node_def()) {
if (flib.Contains(node.op())) {
@ -651,7 +651,7 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
};
// Perform DFS for all tf.data functions in `fdef_lib`.
for (const auto& func_def : fdef_lib.function()) {
const string& func_name = func_def.signature().name();
const std::string& func_name = func_def.signature().name();
if (data::IsTFDataFunction(func_def))
collect_tf_data_functions_dfs(func_name);
}
@ -659,7 +659,7 @@ void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
// because `FunctionLibraryDefinition` does not seem to provide mutable access
// to a `FunctionDef`.
for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
const string& func_name = func_def.signature().name();
const std::string& func_name = func_def.signature().name();
if (tf_data_functions.contains(func_name) &&
!data::IsTFDataFunction(func_def)) {
VLOG(2) << "Marking " << func_name << " as tf.data function";

View File

@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include <atomic>
#include "absl/strings/match.h"
#include "absl/strings/substitute.h"
#include "tensorflow/cc/ops/standard_ops.h"
@ -1020,14 +1022,19 @@ TEST_F(MetaOptimizerTest, CompressConstants) {
// Tests for checking expected behavior when skipping tf.data functions in
// meta optimizer.
// Custom optimizer which counts its calls.
// Custom optimizer which counts the number of calls of its method `Optimize`
// across all class instances.
class TfDataTestOptimizer : public CustomGraphOptimizer {
public:
static void InitCount() { cnt_ = 0; }
static int GetCount() { return cnt_; }
static void InitCount() { count_ = 0; }
static int GetCount() { return count_; }
TfDataTestOptimizer() {}
string name() const override { return "tf_data_test_optimizer"; }
TfDataTestOptimizer() = default;
~TfDataTestOptimizer() override = default;
TfDataTestOptimizer(const TfDataTestOptimizer&) = delete;
TfDataTestOptimizer& operator=(const TfDataTestOptimizer& other) = delete;
std::string name() const override { return "tf_data_test_optimizer"; }
bool UsesFunctionLibrary() const override { return false; }
Status Init(
@ -1037,7 +1044,7 @@ class TfDataTestOptimizer : public CustomGraphOptimizer {
Status Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) override {
++cnt_;
++count_;
*optimized_graph = item.graph;
return Status::OK();
}
@ -1046,69 +1053,128 @@ class TfDataTestOptimizer : public CustomGraphOptimizer {
const GraphDef& optimized_graph, double result) override {}
private:
static int cnt_;
static std::atomic<int> count_;
};
int TfDataTestOptimizer::cnt_;
std::atomic<int> TfDataTestOptimizer::count_;
REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);
// Test fixture for parametrized testing.
class TfDataTestFixture
: public ::testing::TestWithParam<std::tuple<bool, bool>> {
protected:
void SetUp() override {
is_my_mul_tf_data_ = std::get<0>(GetParam());
is_my_square_tf_data_ = std::get<1>(GetParam());
}
void RunTest();
private:
// controls which of the functions is flagged as tf.data function
bool is_my_mul_tf_data_ = false;
bool is_my_square_tf_data_ = false;
// Describes the way `inner_func` is referenced from within `outer_func`.
enum class FuncNestingType {
  // `inner_func` is called from a node in `outer_func`.
  CallFromNode = 0,
  // `inner_func` is referenced from an attribute of a node in `outer_func`.
  CallFromAttr,
  // `inner_func` is referenced from a list attribute of a node in
  // `outer_func`.
  CallFromList,
};
TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); }
// Value-parameterized test fixture. The parameter tuple holds, in order:
// whether the inner function is flagged as tf.data function, whether the outer
// function is flagged as tf.data function, and how the inner function is
// nested inside the outer function.
class TfDataTestFixture
    : public ::testing::TestWithParam<std::tuple<bool, bool, FuncNestingType>> {
 protected:
  // Unpacks the test parameter into the members below.
  void SetUp() override {
    std::tie(is_inner_func_tf_data_, is_outer_func_tf_data_,
             func_nesting_type_) = GetParam();
  }

  // Flags selecting which of the two functions is marked as tf.data function.
  bool is_inner_func_tf_data_ = false;
  bool is_outer_func_tf_data_ = false;
  // Selects the way the inner function is nested inside the outer function.
  FuncNestingType func_nesting_type_ = FuncNestingType::CallFromNode;
};
// Core test function.
void TfDataTestFixture::RunTest() {
// Helper functions for setting up the call of `inner_func` inside of
// `outer_func`.
// Overwrites `outer_func` with a function named "outer_func" (input `x:float`,
// output `z:float`) that contains a single node whose op is `inner_func`.
void SetUpCallFromNode(FunctionDef& outer_func) {
// Call `inner_func` from a node in `outer_func`.
outer_func = FunctionDefHelper::Create(
"outer_func", {"x:float"}, {"z:float"}, {},
/*node_def=*/
// One node named "inner_func" with op `inner_func`; `x` is passed twice.
{{{"inner_func"}, "inner_func", {"x", "x"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/
// The function's return value `z` is taken from that node's output `z`.
{{"z", "inner_func:z:0"}});
}
// Overwrites `outer_func` with a function named "outer_func" (input `x:float`,
// output `z:float`) whose only node is an `Identity` node that references
// `inner_func` through its attribute "f" rather than using it as the node op.
void SetUpCallFromAttr(FunctionDef& outer_func) {
// Call `inner_func` from an attribute in a node in `outer_func`.
outer_func = FunctionDefHelper::Create(
"outer_func", {"x:float"}, {"z:float"}, {},
/*node_def=*/
{{{"identity"},
"Identity",
{"x"},
{{"T", DT_FLOAT},
// Attribute "f" carries the function reference to `inner_func`.
{"f", FunctionDefHelper::FunctionRef("inner_func", {})}}}},
/*ret_def=*/
// The return value `z` is mapped to the function input `x`.
{{"z", "x"}});
}
// Overwrites `outer_func` with a function named "outer_func" (input `x:float`,
// output `z:float`) whose single `Identity` node references `inner_func` from
// a list attribute.
void SetUpCallFromList(FunctionDef& outer_func) {
  // Start from a function that contains just one `Identity` node.
  outer_func = FunctionDefHelper::Create(
      "outer_func", {"x:float"}, {"z:float"}, {},
      /*node_def=*/
      {{{"identity"}, "Identity", {"x"}, {{"T", DT_FLOAT}}}},
      /*ret_def=*/
      {{"z", "x"}});
  // Attach a list attribute named "list" to the `identity` node and append an
  // entry naming `inner_func`. Every pointer used here is obtained from a
  // `mutable_*()`/`add_*()` accessor on `outer_func`, so nothing is allocated
  // or released manually.
  auto* identity_node = outer_func.mutable_node_def(0);
  auto& identity_attrs = *identity_node->mutable_attr();
  auto* func_list = identity_attrs["list"].mutable_list();
  func_list->add_func()->set_name("inner_func");
}
TEST_P(TfDataTestFixture, TfDataTests) {
using test::function::NDef;
// Define function library:
//
// MyMul(x, y) = x * y
// MySquare(x) = MyMul(x, x)
// Define function library with `outer_func` and `inner_func`.
FunctionDef mul_func = FunctionDefHelper::Create(
"MyMul", {"x:float", "y:float"}, {"z:float"}, {},
FunctionDef inner_func = FunctionDefHelper::Create(
"inner_func", {"x:float", "y:float"}, {"z:float"}, {},
/*node_def=*/
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/
{{"z", "mul:z:0"}});
(*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_);
(*inner_func.mutable_attr())[data::kTFDataFunction].set_b(
is_inner_func_tf_data_);
FunctionDef square_func = FunctionDefHelper::Create(
"MySquare", {"x:float"}, {"z:float"}, {},
{{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}},
/*ret_def=*/
{{"z", "my_mul:z:0"}});
(*square_func.mutable_attr())[data::kTFDataFunction].set_b(
is_my_square_tf_data_);
FunctionDef outer_func;
switch (func_nesting_type_) {
case FuncNestingType::CallFromNode:
SetUpCallFromNode(outer_func);
break;
case FuncNestingType::CallFromAttr:
SetUpCallFromAttr(outer_func);
break;
case FuncNestingType::CallFromList:
SetUpCallFromList(outer_func);
break;
default:
break;
}
(*outer_func.mutable_attr())[data::kTFDataFunction].set_b(
is_outer_func_tf_data_);
// Tensorflow graph:
//
// a = tf.Placeholder(tf.float);
// square = MySquare(a); // a^2
// result = outer_func(a);
GrapplerItem item;
item.id = "tf_graph";
item.graph = test::function::GDef(
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
// Calls into function library
NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
NDef("outer_func_node", "outer_func", {"a"}, {{"T", DT_FLOAT}}, kDevice),
// Forward outputs
NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
NDef("out_s", "Identity", {"outer_func_node:0"}, {{"T", DT_FLOAT}},
kDevice)},
/*funcs=*/
{mul_func, square_func});
{inner_func, outer_func});
// Use only custom optimizer which counts its calls.
TfDataTestOptimizer::InitCount();
@ -1125,35 +1191,67 @@ void TfDataTestFixture::RunTest() {
TF_EXPECT_OK(status);
// We expect one graph optimization + one optimization for each non-tf.data
// function. Note that if `MySquare` is flagged as a tf.data function, then
// `MyMul` is implicitly also considered a tf.data function because it is
// called from `MySquare`.
// function. Note that if `outer_func` is flagged as a tf.data function, then
// `inner_func` is implicitly also considered a tf.data function because it is
// called from `outer_func`.
int expected_count = 3;
if (is_my_square_tf_data_)
expected_count -= 2;
else if (is_my_mul_tf_data_)
expected_count -= 1;
if (is_outer_func_tf_data_)
expected_count = 1;
else if (is_inner_func_tf_data_)
expected_count = 2;
EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);
// We expect that the tf.data-attribute has been propagated from `MySquare`
// to its callee `MyMul` if the value is `true`. Otherwise, the attribute
// We expect that the tf.data-attribute has been propagated from `outer_func`
// to its callee `inner_func` if the value is `true`. Otherwise, the attribute
// values should be unchanged.
FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
const FunctionDef* square_func_after_opt = flib.Find("MySquare");
const FunctionDef* mul_func_after_opt = flib.Find("MyMul");
const FunctionDef* outer_func_after_opt = flib.Find("outer_func");
const FunctionDef* inner_func_after_opt = flib.Find("inner_func");
EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt),
is_my_square_tf_data_);
if (is_my_square_tf_data_ || is_my_mul_tf_data_) {
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true);
EXPECT_EQ(data::IsTFDataFunction(*outer_func_after_opt),
is_outer_func_tf_data_);
if (is_outer_func_tf_data_ || is_inner_func_tf_data_) {
EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), true);
} else {
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false);
EXPECT_EQ(data::IsTFDataFunction(*inner_func_after_opt), false);
}
}
INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture,
::testing::Combine(::testing::Bool(),
::testing::Bool()));
// Instantiates the fixture for every combination of the two tf.data flags and
// the three nesting types, and gives each instance a descriptive name.
INSTANTIATE_TEST_SUITE_P(
    MetaOptimizerTest, TfDataTestFixture,
    ::testing::Combine(::testing::Bool(), ::testing::Bool(),
                       ::testing::Values(FuncNestingType::CallFromNode,
                                         FuncNestingType::CallFromAttr,
                                         FuncNestingType::CallFromList)),
    // Name generator: "<which funcs are tf.data>_<how inner_func is nested>".
    [](const ::testing::TestParamInfo<TfDataTestFixture::ParamType>& info) {
      const bool inner_is_tf_data = std::get<0>(info.param);
      const bool outer_is_tf_data = std::get<1>(info.param);

      // First part: which of the two functions carry the tf.data flag.
      std::string name;
      if (inner_is_tf_data) {
        name = outer_is_tf_data ? "both_funcs_tf_data" : "inner_func_tf_data";
      } else {
        name = outer_is_tf_data ? "outer_func_tf_data" : "no_func_tf_data";
      }

      // Second part: the nesting type. Unknown values add no suffix.
      switch (std::get<2>(info.param)) {
        case FuncNestingType::CallFromNode:
          name += "_call_from_node";
          break;
        case FuncNestingType::CallFromAttr:
          name += "_call_from_attribute";
          break;
        case FuncNestingType::CallFromList:
          name += "_call_from_list";
          break;
        default:
          break;
      }
      return name;
    });
} // namespace
} // namespace grappler