Copy tensorflow.python.saved_model.loader_impl.get_train_op to test file
#KERAS_PRIVATE_API_CLEANUP PiperOrigin-RevId: 339917969 Change-Id: Iadd13a8d23a941528e0384e090dc07ddc9d63da6
This commit is contained in:
parent
d266494953
commit
c31a3d0cce
@ -278,6 +278,13 @@ def load_model(sess, path, mode):
|
|||||||
return inputs, outputs, meta_graph_def
|
return inputs, outputs, meta_graph_def
|
||||||
|
|
||||||
|
|
||||||
|
def get_train_op(meta_graph_def):
|
||||||
|
graph = ops.get_default_graph()
|
||||||
|
signature_def = meta_graph_def.signature_def['__saved_model_train_op']
|
||||||
|
op_name = signature_def.outputs['__saved_model_train_op'].name
|
||||||
|
return graph.as_graph_element(op_name)
|
||||||
|
|
||||||
|
|
||||||
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def _save_model_dir(self, dirname='saved_model'):
|
def _save_model_dir(self, dirname='saved_model'):
|
||||||
@ -402,7 +409,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
|||||||
self.assertIn('predictions/' + output_name, outputs)
|
self.assertIn('predictions/' + output_name, outputs)
|
||||||
|
|
||||||
# Train for a step
|
# Train for a step
|
||||||
train_op = loader_impl.get_train_op(meta_graph_def)
|
train_op = get_train_op(meta_graph_def)
|
||||||
train_outputs, _ = sess.run(
|
train_outputs, _ = sess.run(
|
||||||
[outputs, train_op], {inputs[input_name]: input_arr,
|
[outputs, train_op], {inputs[input_name]: input_arr,
|
||||||
inputs[target_name]: target_arr})
|
inputs[target_name]: target_arr})
|
||||||
|
Loading…
x
Reference in New Issue
Block a user