Run placer for the inlined function body.
PiperOrigin-RevId: 223878376
This commit is contained in:
parent
296e4a05e4
commit
50204e4205
@ -595,6 +595,13 @@ Status GraphExecutionState::OptimizeGraph(
|
||||
grappler::GrapplerItem item;
|
||||
item.id = "tf_graph";
|
||||
graph_->ToGraphDef(&item.graph);
|
||||
|
||||
// It's ok to skip invalid device annotations in Grappler.
|
||||
Status inferred_devices = item.InferDevicesFromGraph();
|
||||
if (!inferred_devices.ok()) {
|
||||
VLOG(3) << inferred_devices.error_message();
|
||||
}
|
||||
|
||||
// TODO(b/114748242): Add a unit test to test this bug fix.
|
||||
if (flib_def_) {
|
||||
*item.graph.mutable_library() = flib_def_->ToProto();
|
||||
|
@ -107,6 +107,8 @@ cc_library(
|
||||
":utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
@ -141,6 +143,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
],
|
||||
)
|
||||
|
@ -19,10 +19,13 @@ limitations under the License.
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
@ -38,7 +41,8 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
|
||||
restore_op = other.restore_op;
|
||||
save_restore_loc_tensor = other.save_restore_loc_tensor;
|
||||
queue_runners = other.queue_runners;
|
||||
allowed_optimizations = other.allowed_optimizations;
|
||||
devices_ = other.devices_;
|
||||
allowed_optimizations_ = other.allowed_optimizations_;
|
||||
graph.Swap(graph_def);
|
||||
}
|
||||
|
||||
@ -111,6 +115,64 @@ std::unordered_set<string> GrapplerItem::NodesToPreserve() const {
|
||||
return result;
|
||||
}
|
||||
|
||||
const std::unordered_set<string>& GrapplerItem::devices() const {
|
||||
return devices_;
|
||||
}
|
||||
|
||||
Status GrapplerItem::AddDevice(const string& device) {
|
||||
DeviceNameUtils::ParsedName name;
|
||||
|
||||
if (!DeviceNameUtils::ParseFullName(device, &name)) {
|
||||
return errors::InvalidArgument("Invalid device name: device=", device);
|
||||
|
||||
} else if (!name.has_job || !name.has_replica || !name.has_task ||
|
||||
!name.has_type || !name.has_id) {
|
||||
return errors::InvalidArgument("Not a fully defined device name: device=",
|
||||
device);
|
||||
}
|
||||
|
||||
devices_.insert(DeviceNameUtils::ParsedNameToString(name));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GrapplerItem::AddDevices(const GrapplerItem& other) {
|
||||
std::vector<absl::string_view> invalid_devices;
|
||||
for (const string& device : other.devices()) {
|
||||
Status added = AddDevice(device);
|
||||
if (!added.ok()) invalid_devices.emplace_back(device);
|
||||
}
|
||||
return invalid_devices.empty()
|
||||
? Status::OK()
|
||||
: errors::InvalidArgument("Skipped invalid devices: [",
|
||||
absl::StrJoin(invalid_devices, ", "),
|
||||
"]");
|
||||
}
|
||||
|
||||
Status GrapplerItem::InferDevicesFromGraph() {
|
||||
absl::flat_hash_set<absl::string_view> invalid_devices;
|
||||
for (const NodeDef& node : graph.node()) {
|
||||
Status added = AddDevice(node.device());
|
||||
if (!added.ok()) invalid_devices.insert(node.device());
|
||||
}
|
||||
VLOG(2) << "Inferred device set: [" << absl::StrJoin(devices_, ", ") << "]";
|
||||
return invalid_devices.empty()
|
||||
? Status::OK()
|
||||
: errors::InvalidArgument("Skipped invalid devices: [",
|
||||
absl::StrJoin(invalid_devices, ", "),
|
||||
"]");
|
||||
}
|
||||
|
||||
void GrapplerItem::ClearDevices() { devices_.clear(); }
|
||||
|
||||
const GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations()
|
||||
const {
|
||||
return allowed_optimizations_;
|
||||
}
|
||||
|
||||
GrapplerItem::AllowedOptimizations& GrapplerItem::allowed_optimizations() {
|
||||
return allowed_optimizations_;
|
||||
}
|
||||
|
||||
std::vector<const NodeDef*> ComputeTransitiveFanin(
|
||||
const GraphDef& graph, const std::vector<string>& terminal_nodes) {
|
||||
bool ill_formed = false;
|
||||
|
@ -85,7 +85,33 @@ struct GrapplerItem {
|
||||
bool non_differentiable_rewrites = true;
|
||||
};
|
||||
|
||||
AllowedOptimizations allowed_optimizations;
|
||||
const std::unordered_set<string>& devices() const;
|
||||
// Adds a device to a set of available devices, only if it's a valid fully
|
||||
// defined device name. Returns `Status::OK()` if successfully added a device,
|
||||
// and an error otherwise.
|
||||
Status AddDevice(const string& device);
|
||||
// Adds all valid devices from the other Grappler item to the device set.
|
||||
Status AddDevices(const GrapplerItem& other);
|
||||
// Adds all valid devices from the nodes of the graph to the device set.
|
||||
// Returns `Status::OK()` if all device annotations found in a graph are valid
|
||||
// fully defined device names, and an error otherwise.
|
||||
Status InferDevicesFromGraph();
|
||||
// Clears a set of available devices.
|
||||
void ClearDevices();
|
||||
|
||||
const AllowedOptimizations& allowed_optimizations() const;
|
||||
AllowedOptimizations& allowed_optimizations();
|
||||
|
||||
private:
|
||||
// TODO(ezhulenev) Make GrapplerItem a class and hide all public data members.
|
||||
// TODO(ezhulenev): Migrate all unordered collections to absl.
|
||||
|
||||
// A set of fully defined device names that can be used to place the nodes of
|
||||
// the `graph`.
|
||||
// Example of a fully defined name: "/job:work/replica:1/task:1/device:CPU:0"
|
||||
std::unordered_set<string> devices_;
|
||||
|
||||
AllowedOptimizations allowed_optimizations_;
|
||||
};
|
||||
|
||||
// Return the transitive fanin of a set of terminal nodes.
|
||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/framework/function_testlib.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
@ -44,6 +46,32 @@ TEST_F(GrapplerItemTest, Basic) {
|
||||
EXPECT_EQ(main_ops, graph_nodes);
|
||||
}
|
||||
|
||||
TEST_F(GrapplerItemTest, InferDevices) {
|
||||
using test::function::NDef;
|
||||
|
||||
const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
|
||||
const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
|
||||
const string cpu2 = "/device:CPU:2";
|
||||
|
||||
GrapplerItem item;
|
||||
item.graph = test::function::GDef(
|
||||
{
|
||||
NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
|
||||
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
|
||||
NDef("c", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu2),
|
||||
},
|
||||
{} /* Empty function library */);
|
||||
|
||||
ASSERT_FALSE(item.InferDevicesFromGraph().ok());
|
||||
|
||||
EXPECT_EQ(item.devices().size(), 2);
|
||||
EXPECT_NE(item.devices().find(cpu0), item.devices().end());
|
||||
EXPECT_NE(item.devices().find(cpu1), item.devices().end());
|
||||
|
||||
item.ClearDevices();
|
||||
EXPECT_EQ(item.devices().size(), 0);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
@ -141,6 +141,7 @@ cc_library(
|
||||
deps = [
|
||||
":graph_optimizer",
|
||||
"//tensorflow/core:core_cpu_base",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -3572,7 +3572,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
|
||||
|
||||
// Disable restricted graph rewrites.
|
||||
options_.unary_ops_composition &=
|
||||
item.allowed_optimizations.non_differentiable_rewrites;
|
||||
item.allowed_optimizations().non_differentiable_rewrites;
|
||||
|
||||
if (options_.dedup_computations) {
|
||||
DedupComputations();
|
||||
|
@ -22,8 +22,11 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_replace.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/device_set.h"
|
||||
#include "tensorflow/core/common_runtime/function.h"
|
||||
#include "tensorflow/core/common_runtime/placer.h"
|
||||
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
@ -111,10 +114,26 @@ AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
|
||||
}
|
||||
}
|
||||
|
||||
class FakeCPUDevice : public Device {
|
||||
// This is a fake device that should not be used for any op kernel execution,
|
||||
// the only purpose of this device is to be passed as a part of DeviceSet to the
|
||||
// Placer.
|
||||
class FakeDevice : public Device {
|
||||
public:
|
||||
FakeCPUDevice(Env* env, const DeviceAttributes& attr) : Device(env, attr) {}
|
||||
FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
|
||||
explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
|
||||
Status Sync() override { return Status::OK(); }
|
||||
|
||||
private:
|
||||
static DeviceAttributes attr(const string& device) {
|
||||
DeviceNameUtils::ParsedName parsed_name;
|
||||
bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
|
||||
DCHECK(parsed) << "Failed to parse full device name: " << device;
|
||||
|
||||
DeviceAttributes attr;
|
||||
attr.set_name(device);
|
||||
attr.set_device_type(parsed_name.type);
|
||||
return attr;
|
||||
}
|
||||
};
|
||||
|
||||
// -------------------------------------------------------------------------- //
|
||||
@ -240,6 +259,7 @@ class FunctionOptimizerContext {
|
||||
graph_version_(item.graph.versions().producer()),
|
||||
opt_level_(opt_level),
|
||||
function_library_(OpRegistry::Global(), item.graph.library()),
|
||||
available_device_names_(item.devices().begin(), item.devices().end()),
|
||||
graph_view_(&item.graph) {
|
||||
InitializeTrulyConstNodes(item);
|
||||
InitializeFetchNodes(item);
|
||||
@ -275,6 +295,18 @@ class FunctionOptimizerContext {
|
||||
|
||||
const gtl::FlatSet<string>& fetch_tensors() const { return fetch_tensors_; }
|
||||
|
||||
const DeviceSet* devices() const {
|
||||
// Create fake devices lazily only if we need a DeviceSet.
|
||||
if (available_devices_.empty() && !available_device_names_.empty()) {
|
||||
for (const string& name : available_device_names_) {
|
||||
auto device = absl::make_unique<FakeDevice>(name);
|
||||
available_device_set_.AddDevice(device.get());
|
||||
available_devices_.push_back(std::move(device));
|
||||
}
|
||||
}
|
||||
return &available_device_set_;
|
||||
}
|
||||
|
||||
bool IsFetchNode(const string& node_name) const {
|
||||
return fetch_nodes_.find(node_name) != fetch_nodes_.end();
|
||||
}
|
||||
@ -350,11 +382,8 @@ class FunctionOptimizerContext {
|
||||
void InitializeFunctionLibraryRuntime() {
|
||||
if (!flr_) {
|
||||
Env* env = Env::Default();
|
||||
DeviceAttributes attr;
|
||||
attr.set_name("/device:CPU:0");
|
||||
attr.set_device_type("CPU");
|
||||
std::vector<std::unique_ptr<Device>> devices;
|
||||
devices.push_back(absl::make_unique<FakeCPUDevice>(env, attr));
|
||||
devices.push_back(absl::make_unique<FakeDevice>(env, "/device:CPU:0"));
|
||||
device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
|
||||
OptimizerOptions optimizer_opts;
|
||||
optimizer_opts.set_do_function_inlining(true);
|
||||
@ -375,6 +404,16 @@ class FunctionOptimizerContext {
|
||||
std::unique_ptr<ProcessFunctionLibraryRuntime> process_flr_;
|
||||
FunctionLibraryRuntime* flr_ = nullptr;
|
||||
|
||||
// Fully defined names of the devices available to the GrapplerItem.
|
||||
const gtl::FlatSet<string> available_device_names_;
|
||||
|
||||
// List of available `FakedDevices` (lazily initialized, see devices()).
|
||||
mutable std::vector<std::unique_ptr<Device>> available_devices_;
|
||||
|
||||
// DeviceSet of fake devices (`FakeDevice`) constructed from
|
||||
// available_devices_ (lazily initialized).
|
||||
mutable DeviceSet available_device_set_;
|
||||
|
||||
// Nodes that are Const and not in feed.
|
||||
std::unordered_map<string, const NodeDef*> truly_const_nodes_;
|
||||
// Specialized functions.
|
||||
@ -1305,7 +1344,58 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
// edges from inlined side-effectful ops.
|
||||
std::vector<string> side_effectful_nodes;
|
||||
|
||||
for (NodeDef& func_body_node : *item.mutable_function_body().mutable_node()) {
|
||||
// ------------------------------------------------------------------------ //
|
||||
// First we need to assign device placements to all function body nodes.
|
||||
|
||||
GraphDef placed_graph_def;
|
||||
|
||||
const DeviceSet* devices = ctx->devices();
|
||||
|
||||
if (devices->devices().empty()) {
|
||||
// If there are no devices available for placer, we just put all nodes to
|
||||
// the same device as a function caller node. This can happen if Grappler is
|
||||
// running "offline", without active runtime session, for example as a part
|
||||
// of a batch job for graph analysis/optimization.
|
||||
VLOG(3) << "Assign function call node device to all function body nodes. "
|
||||
<< "Device: " << func_node.device();
|
||||
placed_graph_def = item.mutable_function_body();
|
||||
for (NodeDef& node : *placed_graph_def.mutable_node()) {
|
||||
node.set_device(func_node.device());
|
||||
}
|
||||
} else {
|
||||
// If we are running in an active runtime session, Grappler will get the
|
||||
// graph after initial placing is done, and we should have devices for the
|
||||
// placer.
|
||||
VLOG(3) << "Run placer for instantiated function body. Devices: ["
|
||||
<< absl::StrJoin(
|
||||
devices->devices(), ", ",
|
||||
[](string* out, const Device* d) { out->append(d->name()); })
|
||||
<< "]";
|
||||
|
||||
// Construct a Graph object from the instantiated function body.
|
||||
GraphConstructorOptions opts;
|
||||
Graph graph(ctx->function_library());
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConvertGraphDefToGraph(opts, item.function_body(), &graph));
|
||||
|
||||
// Use function caller node device as a default for placer.
|
||||
const Device* default_device =
|
||||
devices->FindDeviceByName(func_node.device());
|
||||
|
||||
Placer placer(&graph, devices, nullptr, /* No session options */
|
||||
default_device);
|
||||
TF_RETURN_IF_ERROR(placer.Run());
|
||||
|
||||
// Convert Graph back to the GraphDef.
|
||||
graph.ToGraphDef(&placed_graph_def);
|
||||
}
|
||||
|
||||
// ------------------------------------------------------------------------ //
|
||||
// After all nodes placed we need to prepare them for inlining into the
|
||||
// optimized graph: turn placeholders into identities, update nodes
|
||||
// connectivity, etc...
|
||||
|
||||
for (NodeDef& func_body_node : *placed_graph_def.mutable_node()) {
|
||||
if (item.IsInputPlaceholder(func_body_node.name())) {
|
||||
// Turn input placeholders into identity node.
|
||||
DCHECK_EQ(0, func_body_node.input_size());
|
||||
@ -1337,10 +1427,6 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &func_body_node));
|
||||
|
||||
// TODO(ezhulenev): Call PartitionedCallOp to get placement for all function
|
||||
// body nodes, for now just place it on function caller node device.
|
||||
func_body_node.set_device(func_node.device());
|
||||
|
||||
// If the function body has a side-effectful op, we double check that the
|
||||
// function call node has an output control edge, otherwise we can't safely
|
||||
// do inlining and guarantee that node will be executed.
|
||||
@ -1362,7 +1448,7 @@ Status InlineIndirectFunctionCall(const NodeDef& func_node,
|
||||
|
||||
// TODO(ezhulenev): Inline nested indirect function calls.
|
||||
|
||||
// Move the node to the main graph.
|
||||
// Move the node to the optimized graph.
|
||||
optimized_graph->add_node()->Swap(&func_body_node);
|
||||
}
|
||||
|
||||
|
@ -881,6 +881,59 @@ TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithControlDependencies) {
|
||||
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
|
||||
}
|
||||
|
||||
TEST_F(FunctionOptimizerTest, InlineIndirectFunctionWithDevicePlacement) {
|
||||
using test::function::NDef;
|
||||
using FDH = FunctionDefHelper;
|
||||
|
||||
FunctionOptimizer optimizer(RewriterConfig::AGGRESSIVE);
|
||||
|
||||
FunctionDef mul_func = FunctionDefHelper::Create(
|
||||
"MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
|
||||
{{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
|
||||
/* Mapping between function returns and function node outputs. */
|
||||
{{"z", "mul:z:0"}});
|
||||
// Add device placement spec to the function body node.
|
||||
(*mul_func.mutable_node_def())[0].set_device("/device:CPU:1");
|
||||
|
||||
// We need fully defined device names to run the placer for inlined function.
|
||||
const string cpu0 = "/job:work/replica:1/task:1/device:CPU:0";
|
||||
const string cpu1 = "/job:work/replica:1/task:1/device:CPU:1";
|
||||
|
||||
// Build a graph to compute c = MyMul(a, b)
|
||||
GrapplerItem item;
|
||||
item.fetch = {"d"};
|
||||
item.graph = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
|
||||
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
|
||||
NDef("c", "PartitionedCall", {"a", "b"},
|
||||
{{"Tin", DataTypeSlice{DT_FLOAT, DT_FLOAT}},
|
||||
{"Tout", DataTypeSlice{DT_FLOAT}},
|
||||
{"f", FDH::FunctionRef("MyMul", {{"T", DT_FLOAT}})}},
|
||||
cpu0),
|
||||
NDef("d", "Identity", {"c"}, {{"T", DT_FLOAT}}, cpu0)},
|
||||
// Function library.
|
||||
{mul_func});
|
||||
ASSERT_TRUE(item.InferDevicesFromGraph().ok());
|
||||
|
||||
GraphDef optimized_graph;
|
||||
TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &optimized_graph));
|
||||
|
||||
GraphDef expected = test::function::GDef(
|
||||
{NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu0),
|
||||
NDef("b", "Placeholder", {}, {{"dtype", DT_FLOAT}}, cpu1),
|
||||
|
||||
// Function must be inlined and `mul` node placed on a requested device.
|
||||
NDef("c/x", "Identity", {"a:0"}, {{"T", DT_FLOAT}}, cpu1),
|
||||
NDef("c/y", "Identity", {"b:0"}, {{"T", DT_FLOAT}}, cpu1),
|
||||
NDef("c/mul", "Mul", {"c/x", "c/y"}, {{"T", DT_FLOAT}}, cpu1),
|
||||
|
||||
NDef("d", "Identity", {"c/mul:0"}, {{"T", DT_FLOAT}}, cpu0)},
|
||||
// Function library.
|
||||
{mul_func});
|
||||
|
||||
CompareGraphs(expected, optimized_graph);
|
||||
}
|
||||
|
||||
TEST_F(FunctionOptimizerTest, SpecializeFunctionXTimesTwo) {
|
||||
using test::function::NDef;
|
||||
|
||||
|
@ -524,7 +524,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
// can't perform non-differentiable rewrites.
|
||||
if (differentiable_functions.find(func_name) !=
|
||||
differentiable_functions.end()) {
|
||||
func_item.allowed_optimizations.non_differentiable_rewrites = false;
|
||||
func_item.allowed_optimizations().non_differentiable_rewrites = false;
|
||||
}
|
||||
|
||||
// Function item is allowed to use all devices from the main graph.
|
||||
Status added_devices = func_item.AddDevices(item);
|
||||
if (!added_devices.ok()) {
|
||||
VLOG(3) << added_devices.error_message();
|
||||
}
|
||||
|
||||
// Optimize function body graph.
|
||||
|
@ -108,7 +108,7 @@ class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
|
||||
GraphDef* optimized_graph) override {
|
||||
*optimized_graph = item.graph;
|
||||
if (allowed_optimizations_) {
|
||||
allowed_optimizations_->insert({item.id, item.allowed_optimizations});
|
||||
allowed_optimizations_->insert({item.id, item.allowed_optimizations()});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -114,6 +114,8 @@ void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) const {
|
||||
for (int i = 0; i < want.node_size(); ++i) {
|
||||
EXPECT_EQ(want.node(i).op(), got.node(i).op());
|
||||
EXPECT_EQ(want.node(i).name(), got.node(i).name());
|
||||
EXPECT_EQ(want.node(i).device(), got.node(i).device());
|
||||
|
||||
ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size());
|
||||
for (int j = 0; j < want.node(i).input_size(); ++j) {
|
||||
const TensorId want_tensor = ParseTensorName(want.node(i).input(j));
|
||||
|
Loading…
x
Reference in New Issue
Block a user