Copy tensorflow.python.saved_model.loader_impl.get_train_op to test file
#KERAS_PRIVATE_API_CLEANUP PiperOrigin-RevId: 339917969 Change-Id: Iadd13a8d23a941528e0384e090dc07ddc9d63da6
This commit is contained in:
parent
d266494953
commit
c31a3d0cce
@ -278,6 +278,13 @@ def load_model(sess, path, mode):
|
||||
return inputs, outputs, meta_graph_def
|
||||
|
||||
|
||||
def get_train_op(meta_graph_def):
|
||||
graph = ops.get_default_graph()
|
||||
signature_def = meta_graph_def.signature_def['__saved_model_train_op']
|
||||
op_name = signature_def.outputs['__saved_model_train_op'].name
|
||||
return graph.as_graph_element(op_name)
|
||||
|
||||
|
||||
class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def _save_model_dir(self, dirname='saved_model'):
|
||||
@ -402,7 +409,7 @@ class TestModelSavedModelExport(test.TestCase, parameterized.TestCase):
|
||||
self.assertIn('predictions/' + output_name, outputs)
|
||||
|
||||
# Train for a step
|
||||
train_op = loader_impl.get_train_op(meta_graph_def)
|
||||
train_op = get_train_op(meta_graph_def)
|
||||
train_outputs, _ = sess.run(
|
||||
[outputs, train_op], {inputs[input_name]: input_arr,
|
||||
inputs[target_name]: target_arr})
|
||||
|
Loading…
x
Reference in New Issue
Block a user