From 51b334875125b1e76545d02d2e8e18c7ff2be0af Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 22 Feb 2018 12:18:39 -0800 Subject: [PATCH] Measure the performance of the original placement to ensure that we preserve it in case the placer isn't given enough time to find a better solution. PiperOrigin-RevId: 186655094 --- tensorflow/python/grappler/graph_placer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py index 2cc35367925..1cd51df4d96 100644 --- a/tensorflow/python/grappler/graph_placer.py +++ b/tensorflow/python/grappler/graph_placer.py @@ -68,6 +68,16 @@ def PlaceGraph(metagraph, item = gitem.Item(optimized_metagraph) + # Measure the runtime achievable with the original placement. + try: + _, original_run_time, _ = cluster.MeasureCosts(item) + if verbose: + print("Runtime for original placement: " + str(original_run_time)) + except errors.OpError as e: + if verbose: + print("Original placement isn't feasible: " + str(e)) + original_run_time = hparams.failing_signal + if hparams is None: hparams = hierarchical_controller.hierarchical_controller_hparams() # We run with a single child @@ -98,7 +108,7 @@ def PlaceGraph(metagraph, print("Failed to run graph:" + str(e)) run_time = hparams.failing_signal updated = model.update_reward(sess, run_time, verbose=verbose) - if updated: + if updated and run_time < original_run_time: if verbose: print("Found better placement, with runtime " + str(run_time)) model.export_placement(metagraph)