Internal change

PiperOrigin-RevId: 337234276
Change-Id: Idd6c4db5c9658a468e1d2758b2e84bbcf401e1a6
A. Unique TensorFlower authored 2020-10-14 21:23:33 -07:00, committed by TensorFlower Gardener
parent 938de3ca21
commit 2ae0d560ee
6 changed files with 10 additions and 207 deletions

View File

@@ -75,11 +75,6 @@ constexpr char kTFDataResourceTag[] = "tfdata";
 class DatasetBase;
 class SerializationContext;
-inline bool IsTFDataFunction(const FunctionDef& func) {
-  return (func.attr().contains(data::kTFDataFunction) &&
-          func.attr().at(data::kTFDataFunction).b());
-}
 // Interface for reading values from a key-value store.
 // Used for restoring iterator state. This class is thread safe.
 // Please see comment on IteratorStateWriter for guidance around using the

View File

@@ -717,7 +717,6 @@ tf_cuda_cc_test(
         ":custom_graph_optimizer_registry",
         ":meta_optimizer",
         "//tensorflow/cc:cc_ops",
-        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",

View File

@@ -430,8 +430,9 @@ FunctionDef* FuseFunctions(
     const SetInputFn& set_input, const SetOutputFn& set_output,
     const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
   auto has_attrs = [](const FunctionDef& func) {
-    return !(func.attr_size() == 0 ||
-             (func.attr_size() == 1 && data::IsTFDataFunction(func)));
+    return !(
+        func.attr_size() == 0 ||
+        (func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction)));
   };
   if (has_attrs(first_function) || has_attrs(second_function)) {
     return nullptr;  // Functions with attributes are currently not supported.
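
Note on the replacement above: the removed data::IsTFDataFunction required the `_tf_data_function` attribute to be present and set to true, while the new check only tests for its presence. A minimal side-by-side sketch of the two predicates (illustration only, not part of this commit; it assumes the TensorFlow proto headers and the attribute name "_tf_data_function" referenced elsewhere in this diff):

// Illustration only; not part of this commit. Assumes a build inside the
// TensorFlow tree so the generated proto headers are available.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"

namespace {

// Assumed value of data::kTFDataFunction, matching the comments in this diff.
constexpr char kTFDataFunctionAttr[] = "_tf_data_function";

// Old check (removed from dataset.h): attribute present and set to true.
bool IsTFDataFunctionStrict(const tensorflow::FunctionDef& func) {
  return func.attr().contains(kTFDataFunctionAttr) &&
         func.attr().at(kTFDataFunctionAttr).b();
}

// New check used by this commit: attribute present, value ignored.
bool HasTFDataFunctionAttr(const tensorflow::FunctionDef& func) {
  return func.attr().contains(kTFDataFunctionAttr);
}

}  // namespace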

View File

@@ -118,7 +118,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   for (const auto& name : flib.ListFunctionNames()) {
     auto* func = flib.Find(name);
     // Skip non tf.data functions.
-    if (!data::IsTFDataFunction(*func)) continue;
+    if (!func->attr().contains(data::kTFDataFunction)) continue;
     VLOG(3) << "Optimize function: function=" << func->signature().name();
     optimized_functions = true;

View File

@@ -93,6 +93,10 @@ bool IsRunOnceOptimizer(const string& name) {
          name == "auto_mixed_precision_mkl";
 }
+bool IsTFDataFunction(const FunctionDef& func) {
+  return func.attr().contains(data::kTFDataFunction);
+}
 // Creates a function library stub from a real function library: copy only
 // signatures and attributes of all the function defined in fdef_lib. This stub
 // can be swapped with real function library in a graph, before passing it to
@@ -611,61 +615,6 @@ Status MetaOptimizer::RunOptimizer(
   return Status::OK();
 }
-// Propagates `_tf_data_function` attributes from functions to their callees.
-void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
-                          FunctionDefLibrary& fdef_lib) {
-  // Collect functions that need the attribute in this set.
-  absl::flat_hash_set<string> tf_data_functions;
-  std::function<void(const string&)> collect_tf_data_functions_dfs =
-      [&](const string& func_name) -> void {
-    // Return if we already found and added this function.
-    if (tf_data_functions.contains(func_name)) return;
-    // We only get here if the function is (directly or indirectly) called from
-    // a tf.data function, so add it to the set.
-    tf_data_functions.insert(func_name);
-    const FunctionDef* func_def = flib.Find(func_name);
-    // Skip functions that are not reachable from the optimized graph.
-    if (func_def == nullptr) return;
-    // Proceed with DFS for functions called from current function.
-    for (const NodeDef& node : func_def->node_def()) {
-      if (flib.Contains(node.op())) {
-        // This is a function call node.
-        collect_tf_data_functions_dfs(node.op());
-      }
-      // Check if there are functions in attributes.
-      for (const auto& attr : node.attr()) {
-        const AttrValue& attr_value = attr.second;
-        if (attr_value.has_func()) {
-          collect_tf_data_functions_dfs(attr_value.func().name());
-        }
-        if (attr_value.has_list()) {
-          for (const auto& func : attr_value.list().func()) {
-            collect_tf_data_functions_dfs(func.name());
-          }
-        }
-      }
-    }
-  };
-  // Perform DFS for all tf.data functions in `fdef_lib`.
-  for (const auto& func_def : fdef_lib.function()) {
-    const string& func_name = func_def.signature().name();
-    if (data::IsTFDataFunction(func_def))
-      collect_tf_data_functions_dfs(func_name);
-  }
-  // Set attribute for tf.data functions. We cannot do this in the DFS directly
-  // because `FunctionLibraryDefinition` does not seem to provide mutable access
-  // to a `FunctionDef`.
-  for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
-    const string& func_name = func_def.signature().name();
-    if (tf_data_functions.contains(func_name)) {
-      (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
-    }
-  }
-}
 Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
                                           GraphDef* optimized_graph) {
   const uint64 start_us = Env::Default()->NowMicros();
@@ -773,7 +722,6 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
   for (const FunctionDef& function : optimized_graph->library().function()) {
     find_xla_compiled_functions(function.node_def());
   }
-  PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
   // Optimize each function only once.
   absl::flat_hash_set<string> optimized_funcs;
@@ -799,9 +747,8 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
     // the function optimizer, before we can optimize function body.
     if (IsParametrized(func)) continue;
-    // Skip tf.data functions as they are optimized by tf.data meta optimizer
-    // and in function instantiation.
-    if (data::IsTFDataFunction(func)) continue;
+    // Skip tf.data functions as they are optimized by tf.data meta optimizer.
+    if (IsTFDataFunction(func)) continue;
     VLOG(3) << "Optimize function: function=" << func_name << " ["
             << function_idx++ << " of "
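
For reference, the PropagateTFDataAttrs pass deleted above marks every function reachable from a tf.data function by a depth-first walk of the call graph, which is why the deleted test in the last file expects MyMul to inherit the attribute from MySquare. A simplified standalone model of that traversal (plain C++ with a toy call-graph map in place of TensorFlow's FunctionLibraryDefinition; names are illustrative, not from this commit):

// Simplified model of the removed propagation pass. A "library" is just a map
// from function name to the names of the functions it calls; the real pass
// walks FunctionDef node_defs and function-valued attributes instead.
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using CallGraph = std::map<std::string, std::vector<std::string>>;

void MarkTFDataFunctions(const CallGraph& calls, const std::string& root,
                         std::set<std::string>* tf_data_functions) {
  // Stop if this function was already marked.
  if (!tf_data_functions->insert(root).second) return;
  auto it = calls.find(root);
  if (it == calls.end()) return;  // Not defined in the library; nothing to recurse into.
  for (const std::string& callee : it->second) {
    MarkTFDataFunctions(calls, callee, tf_data_functions);
  }
}

int main() {
  // Mirrors the deleted test: MySquare(x) = MyMul(x, x), and only MySquare
  // carries the _tf_data_function attribute.
  CallGraph calls = {{"MySquare", {"MyMul"}}, {"MyMul", {}}};
  std::set<std::string> tf_data_functions;
  MarkTFDataFunctions(calls, "MySquare", &tf_data_functions);
  // Prints both MySquare and MyMul: the attribute propagates to callees.
  for (const std::string& name : tf_data_functions) std::cout << name << "\n";
}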

View File

@@ -18,7 +18,6 @@ limitations under the License.
 #include "absl/strings/match.h"
 #include "absl/strings/substitute.h"
 #include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/function_testlib.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
@@ -1017,144 +1016,6 @@ TEST_F(MetaOptimizerTest, CompressConstants) {
   }
 }
-// Tests for checking expected behavior when skipping tf.data functions in
-// meta optimizer.
-// Custom optimizer which counts its calls.
-class TfDataTestOptimizer : public CustomGraphOptimizer {
- public:
-  static void InitCount() { cnt_ = 0; }
-  static int GetCount() { return cnt_; }
-  TfDataTestOptimizer() {}
-  string name() const override { return "tf_data_test_optimizer"; }
-  bool UsesFunctionLibrary() const override { return false; }
-  Status Init(
-      const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
-    return Status::OK();
-  }
-  Status Optimize(Cluster* cluster, const GrapplerItem& item,
-                  GraphDef* optimized_graph) override {
-    ++cnt_;
-    *optimized_graph = item.graph;
-    return Status::OK();
-  }
-  void Feedback(Cluster* cluster, const GrapplerItem& item,
-                const GraphDef& optimized_graph, double result) override {}
- private:
-  static int cnt_;
-};
-int TfDataTestOptimizer::cnt_;
-REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);
-// Test fixture for parametrized testing.
-class TfDataTestFixture
-    : public ::testing::TestWithParam<std::tuple<bool, bool>> {
- protected:
-  void SetUp() override {
-    is_my_mul_tf_data_ = std::get<0>(GetParam());
-    is_my_square_tf_data_ = std::get<1>(GetParam());
-  }
-  void RunTest();
- private:
-  // controls which of the functions is flagged as tf.data function
-  bool is_my_mul_tf_data_ = false;
-  bool is_my_square_tf_data_ = false;
-};
-TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); }
-// Core test function.
-void TfDataTestFixture::RunTest() {
-  using test::function::NDef;
-  // Define function library:
-  //
-  //  MyMul(x, y) = x * y
-  //  MySquare(x) = MyMul(x, x)
-  FunctionDef mul_func = FunctionDefHelper::Create(
-      "MyMul", {"x:float", "y:float"}, {"z:float"}, {},
-      {{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
-      /*ret_def=*/
-      {{"z", "mul:z:0"}});
-  (*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_);
-  FunctionDef square_func = FunctionDefHelper::Create(
-      "MySquare", {"x:float"}, {"z:float"}, {},
-      {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}},
-      /*ret_def=*/
-      {{"z", "my_mul:z:0"}});
-  (*square_func.mutable_attr())[data::kTFDataFunction].set_b(
-      is_my_square_tf_data_);
-  // Tensorflow graph:
-  //
-  //   a = tf.Placeholder(tf.float);
-  //   square = MySquare(a);  // a^2
-  GrapplerItem item;
-  item.id = "tf_graph";
-  item.graph = test::function::GDef(
-      {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
-       // Calls into function library
-       NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
-       // Forward outputs
-       NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
-      /*funcs=*/
-      {mul_func, square_func});
-  // Use only custom optimizer which counts its calls.
-  TfDataTestOptimizer::InitCount();
-  ConfigProto config_proto;
-  auto& rewriter_config =
-      *(config_proto.mutable_graph_options()->mutable_rewrite_options());
-  rewriter_config.add_optimizers("TfDataTestOptimizer");
-  rewriter_config.set_min_graph_nodes(-1);
-  rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
-  MetaOptimizer optimizer(nullptr, config_proto);
-  GraphDef output;
-  const Status status = optimizer.Optimize(nullptr, item, &output);
-  TF_EXPECT_OK(status);
-  // We expect one graph optimization + one optimization for each non-tf.data
-  // function. Note that if `MySquare` is flagged as a tf.data function, then
-  // `MyMul` is implicitly also considered a tf.data function because it is
-  // called from `MySquare`.
-  int expected_count = 3;
-  if (is_my_square_tf_data_)
-    expected_count -= 2;
-  else if (is_my_mul_tf_data_)
-    expected_count -= 1;
-  EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);
-  // We expect that the tf.data-attribute has been propagated from `MySquare`
-  // to its callee `MyMul` if the value is `true`. Otherwise, the attribute
-  // values should be unchanged.
-  FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
-  const FunctionDef* square_func_after_opt = flib.Find("MySquare");
-  const FunctionDef* mul_func_after_opt = flib.Find("MyMul");
-  EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt),
-            is_my_square_tf_data_);
-  if (is_my_square_tf_data_ || is_my_mul_tf_data_) {
-    EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true);
-  } else {
-    EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false);
-  }
-}
-INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture,
-                         ::testing::Combine(::testing::Bool(),
-                                            ::testing::Bool()));
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow