Use tf.function's default autograph=True in saved_model/integration_test.
There is currently no demonstrable need to do something non-obvious here. PiperOrigin-RevId: 259782705
This commit is contained in:
parent
7251a1efe4
commit
d0eeef269d
@ -117,7 +117,7 @@ def wrap_keras_model_for_export(model, batch_input_shape,
|
|||||||
# the desired argspec.
|
# the desired argspec.
|
||||||
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
|
def wrapped(*args, **kwargs): # TODO(arnoegw): Can we use call_fn itself?
|
||||||
return call_fn(*args, **kwargs)
|
return call_fn(*args, **kwargs)
|
||||||
traced_call_fn = tf.function(autograph=False)(
|
traced_call_fn = tf.function(
|
||||||
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
|
tf_decorator.make_decorator(call_fn, wrapped, decorator_argspec=argspec))
|
||||||
|
|
||||||
# Now we need to trigger traces for all supported combinations of the
|
# Now we need to trigger traces for all supported combinations of the
|
||||||
|
@ -37,7 +37,7 @@ def main(argv):
|
|||||||
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
|
root.rnn_cell = tf.keras.layers.LSTMCell(units=10, recurrent_initializer=None)
|
||||||
|
|
||||||
# Wrap the rnn_cell.__call__ function and assign to next_state.
|
# Wrap the rnn_cell.__call__ function and assign to next_state.
|
||||||
root.next_state = tf.function(root.rnn_cell.__call__, autograph=False)
|
root.next_state = tf.function(root.rnn_cell.__call__)
|
||||||
|
|
||||||
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
|
# Wrap the rnn_cell.get_initial_function using a decorator and assign to an
|
||||||
# attribute with the same name.
|
# attribute with the same name.
|
||||||
|
Loading…
Reference in New Issue
Block a user