Internal change
PiperOrigin-RevId: 337234276 Change-Id: Idd6c4db5c9658a468e1d2758b2e84bbcf401e1a6
This commit is contained in:
parent
938de3ca21
commit
2ae0d560ee
@ -75,11 +75,6 @@ constexpr char kTFDataResourceTag[] = "tfdata";
|
||||
class DatasetBase;
|
||||
class SerializationContext;
|
||||
|
||||
inline bool IsTFDataFunction(const FunctionDef& func) {
|
||||
return (func.attr().contains(data::kTFDataFunction) &&
|
||||
func.attr().at(data::kTFDataFunction).b());
|
||||
}
|
||||
|
||||
// Interface for reading values from a key-value store.
|
||||
// Used for restoring iterator state. This class is thread safe.
|
||||
// Please see comment on IteratorStateWriter for guidance around using the
|
||||
|
@ -717,7 +717,6 @@ tf_cuda_cc_test(
|
||||
":custom_graph_optimizer_registry",
|
||||
":meta_optimizer",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -430,8 +430,9 @@ FunctionDef* FuseFunctions(
|
||||
const SetInputFn& set_input, const SetOutputFn& set_output,
|
||||
const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
|
||||
auto has_attrs = [](const FunctionDef& func) {
|
||||
return !(func.attr_size() == 0 ||
|
||||
(func.attr_size() == 1 && data::IsTFDataFunction(func)));
|
||||
return !(
|
||||
func.attr_size() == 0 ||
|
||||
(func.attr_size() == 1 && func.attr().contains(data::kTFDataFunction)));
|
||||
};
|
||||
if (has_attrs(first_function) || has_attrs(second_function)) {
|
||||
return nullptr; // Functions with attributes are currently not supported.
|
||||
|
@ -118,7 +118,7 @@ Status TFDataMetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
for (const auto& name : flib.ListFunctionNames()) {
|
||||
auto* func = flib.Find(name);
|
||||
// Skip non tf.data functions.
|
||||
if (!data::IsTFDataFunction(*func)) continue;
|
||||
if (!func->attr().contains(data::kTFDataFunction)) continue;
|
||||
VLOG(3) << "Optimize function: function=" << func->signature().name();
|
||||
optimized_functions = true;
|
||||
|
||||
|
@ -93,6 +93,10 @@ bool IsRunOnceOptimizer(const string& name) {
|
||||
name == "auto_mixed_precision_mkl";
|
||||
}
|
||||
|
||||
bool IsTFDataFunction(const FunctionDef& func) {
|
||||
return func.attr().contains(data::kTFDataFunction);
|
||||
}
|
||||
|
||||
// Creates a function library stub from a real function library: copy only
|
||||
// signatures and attributes of all the function defined in fdef_lib. This stub
|
||||
// can be swapped with real function library in a graph, before passing it to
|
||||
@ -611,61 +615,6 @@ Status MetaOptimizer::RunOptimizer(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Propagates _tf_data_function` attributes from functions to their callees.
|
||||
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
|
||||
FunctionDefLibrary& fdef_lib) {
|
||||
// Collect functions that need the attribute in this set.
|
||||
absl::flat_hash_set<string> tf_data_functions;
|
||||
std::function<void(const string&)> collect_tf_data_functions_dfs =
|
||||
[&](const string& func_name) -> void {
|
||||
// Return if we already found and added this function.
|
||||
if (tf_data_functions.contains(func_name)) return;
|
||||
|
||||
// We only get here if the function is (directly or indirectly) called from
|
||||
// a tf.data function, so add it to the set.
|
||||
tf_data_functions.insert(func_name);
|
||||
|
||||
const FunctionDef* func_def = flib.Find(func_name);
|
||||
// Skip functions that are not reachable from the optimized graph.
|
||||
if (func_def == nullptr) return;
|
||||
|
||||
// Proceed with DFS for functions called from current function.
|
||||
for (const NodeDef& node : func_def->node_def()) {
|
||||
if (flib.Contains(node.op())) {
|
||||
// This is a function call node.
|
||||
collect_tf_data_functions_dfs(node.op());
|
||||
}
|
||||
// Check if there are functions in attributes.
|
||||
for (const auto& attr : node.attr()) {
|
||||
const AttrValue& attr_value = attr.second;
|
||||
if (attr_value.has_func()) {
|
||||
collect_tf_data_functions_dfs(attr_value.func().name());
|
||||
}
|
||||
if (attr_value.has_list()) {
|
||||
for (const auto& func : attr_value.list().func()) {
|
||||
collect_tf_data_functions_dfs(func.name());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
// Perform DFS for all tf.data functions in `fdef_lib`.
|
||||
for (const auto& func_def : fdef_lib.function()) {
|
||||
const string& func_name = func_def.signature().name();
|
||||
if (data::IsTFDataFunction(func_def))
|
||||
collect_tf_data_functions_dfs(func_name);
|
||||
}
|
||||
// Set attribute for tf.data functions. We cannot do this in the DFS directly
|
||||
// because `FunctionLibraryDefinition` does not seem to provide mutable access
|
||||
// to a `FunctionDef`.
|
||||
for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
|
||||
const string& func_name = func_def.signature().name();
|
||||
if (tf_data_functions.contains(func_name)) {
|
||||
(*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
|
||||
GraphDef* optimized_graph) {
|
||||
const uint64 start_us = Env::Default()->NowMicros();
|
||||
@ -773,7 +722,6 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
|
||||
for (const FunctionDef& function : optimized_graph->library().function()) {
|
||||
find_xla_compiled_functions(function.node_def());
|
||||
}
|
||||
PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
|
||||
|
||||
// Optimize each function only once.
|
||||
absl::flat_hash_set<string> optimized_funcs;
|
||||
@ -799,9 +747,8 @@ Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
|
||||
// the function optimizer, before we can optimize function body.
|
||||
if (IsParametrized(func)) continue;
|
||||
|
||||
// Skip tf.data functions as they are optimized by tf.data meta optimizer
|
||||
// and in function instantiation.
|
||||
if (data::IsTFDataFunction(func)) continue;
|
||||
// Skip tf.data functions as they are optimized by tf.data meta optimizer.
|
||||
if (IsTFDataFunction(func)) continue;
|
||||
|
||||
VLOG(3) << "Optimize function: function=" << func_name << " ["
|
||||
<< function_idx++ << " of "
|
||||
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
@ -1017,144 +1016,6 @@ TEST_F(MetaOptimizerTest, CompressConstants) {
|
||||
}
|
||||
}
|
||||
|
||||
// Tests for checking expected behavior when skipping tf.data functions in
|
||||
// meta optimizer.
|
||||
|
||||
// Custom optimizer which counts its calls.
|
||||
class TfDataTestOptimizer : public CustomGraphOptimizer {
|
||||
public:
|
||||
static void InitCount() { cnt_ = 0; }
|
||||
static int GetCount() { return cnt_; }
|
||||
|
||||
TfDataTestOptimizer() {}
|
||||
string name() const override { return "tf_data_test_optimizer"; }
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
|
||||
Status Init(
|
||||
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override {
|
||||
++cnt_;
|
||||
*optimized_graph = item.graph;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override {}
|
||||
|
||||
private:
|
||||
static int cnt_;
|
||||
};
|
||||
|
||||
int TfDataTestOptimizer::cnt_;
|
||||
|
||||
REGISTER_GRAPH_OPTIMIZER(TfDataTestOptimizer);
|
||||
|
||||
// Test fixture for parametrized testing.
|
||||
class TfDataTestFixture
|
||||
: public ::testing::TestWithParam<std::tuple<bool, bool>> {
|
||||
protected:
|
||||
void SetUp() override {
|
||||
is_my_mul_tf_data_ = std::get<0>(GetParam());
|
||||
is_my_square_tf_data_ = std::get<1>(GetParam());
|
||||
}
|
||||
void RunTest();
|
||||
|
||||
private:
|
||||
// controls which of the functions is flagged as tf.data function
|
||||
bool is_my_mul_tf_data_ = false;
|
||||
bool is_my_square_tf_data_ = false;
|
||||
};
|
||||
|
||||
TEST_P(TfDataTestFixture, TfDataTests) { RunTest(); }
|
||||
|
||||
// Core test function.
|
||||
void TfDataTestFixture::RunTest() {
|
||||
using test::function::NDef;
|
||||
|
||||
// Define function library:
|
||||
//
|
||||
// MyMul(x, y) = x * y
|
||||
// MySquare(x) = MyMul(x, x)
|
||||
|
||||
FunctionDef mul_func = FunctionDefHelper::Create(
|
||||
"MyMul", {"x:float", "y:float"}, {"z:float"}, {},
|
||||
{{{"mul"}, "Mul", {"x", "y"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "mul:z:0"}});
|
||||
(*mul_func.mutable_attr())[data::kTFDataFunction].set_b(is_my_mul_tf_data_);
|
||||
|
||||
FunctionDef square_func = FunctionDefHelper::Create(
|
||||
"MySquare", {"x:float"}, {"z:float"}, {},
|
||||
{{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", DT_FLOAT}}}},
|
||||
/*ret_def=*/
|
||||
{{"z", "my_mul:z:0"}});
|
||||
(*square_func.mutable_attr())[data::kTFDataFunction].set_b(
|
||||
is_my_square_tf_data_);
|
||||
|
||||
// Tensorflow graph:
|
||||
//
|
||||
// a = tf.Placeholder(tf.float);
|
||||
// square = MySquare(a); // a^2
|
||||
GrapplerItem item;
|
||||
item.id = "tf_graph";
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
|
||||
// Calls into function library
|
||||
NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
|
||||
// Forward outputs
|
||||
NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice)},
|
||||
/*funcs=*/
|
||||
{mul_func, square_func});
|
||||
|
||||
// Use only custom optimizer which counts its calls.
|
||||
TfDataTestOptimizer::InitCount();
|
||||
ConfigProto config_proto;
|
||||
auto& rewriter_config =
|
||||
*(config_proto.mutable_graph_options()->mutable_rewrite_options());
|
||||
rewriter_config.add_optimizers("TfDataTestOptimizer");
|
||||
rewriter_config.set_min_graph_nodes(-1);
|
||||
rewriter_config.set_meta_optimizer_iterations(RewriterConfig::ONE);
|
||||
|
||||
MetaOptimizer optimizer(nullptr, config_proto);
|
||||
GraphDef output;
|
||||
const Status status = optimizer.Optimize(nullptr, item, &output);
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
// We expect one graph optimization + one optimization for each non-tf.data
|
||||
// function. Note that if `MySquare` is flagged as a tf.data function, then
|
||||
// `MyMul` is implicitly also considered a tf.data function because it is
|
||||
// called from `MySquare`.
|
||||
int expected_count = 3;
|
||||
if (is_my_square_tf_data_)
|
||||
expected_count -= 2;
|
||||
else if (is_my_mul_tf_data_)
|
||||
expected_count -= 1;
|
||||
EXPECT_EQ(TfDataTestOptimizer::GetCount(), expected_count);
|
||||
|
||||
// We expect that the tf.data-attribute has been propagated from `MySquare`
|
||||
// to its callee `MyMul` if the value is `true`. Otherwise, the attribute
|
||||
// values should be unchanged.
|
||||
FunctionLibraryDefinition flib(OpRegistry::Global(), output.library());
|
||||
const FunctionDef* square_func_after_opt = flib.Find("MySquare");
|
||||
const FunctionDef* mul_func_after_opt = flib.Find("MyMul");
|
||||
|
||||
EXPECT_EQ(data::IsTFDataFunction(*square_func_after_opt),
|
||||
is_my_square_tf_data_);
|
||||
if (is_my_square_tf_data_ || is_my_mul_tf_data_) {
|
||||
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), true);
|
||||
} else {
|
||||
EXPECT_EQ(data::IsTFDataFunction(*mul_func_after_opt), false);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MetaOptimizerTest, TfDataTestFixture,
|
||||
::testing::Combine(::testing::Bool(),
|
||||
::testing::Bool()));
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user