From c31a3d0cced49520361b95b78a173ebc0222af22 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Fri, 30 Oct 2020 12:24:11 -0700 Subject: [PATCH] Copy tensorflow.python.saved_model.loader_impl.get_train_op to test file #KERAS_PRIVATE_API_CLEANUP PiperOrigin-RevId: 339917969 Change-Id: Iadd13a8d23a941528e0384e090dc07ddc9d63da6 --- .../python/keras/saving/saved_model_experimental_test.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/saving/saved_model_experimental_test.py b/tensorflow/python/keras/saving/saved_model_experimental_test.py index 45130922250..d3e60ca0088 100644 --- a/tensorflow/python/keras/saving/saved_model_experimental_test.py +++ b/tensorflow/python/keras/saving/saved_model_experimental_test.py @@ -278,6 +278,13 @@ def load_model(sess, path, mode): return inputs, outputs, meta_graph_def +def get_train_op(meta_graph_def): + graph = ops.get_default_graph() + signature_def = meta_graph_def.signature_def['__saved_model_train_op'] + op_name = signature_def.outputs['__saved_model_train_op'].name + return graph.as_graph_element(op_name) + + class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): def _save_model_dir(self, dirname='saved_model'): @@ -402,7 +409,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): self.assertIn('predictions/' + output_name, outputs) # Train for a step - train_op = loader_impl.get_train_op(meta_graph_def) + train_op = get_train_op(meta_graph_def) train_outputs, _ = sess.run( [outputs, train_op], {inputs[input_name]: input_arr, inputs[target_name]: target_arr})