Remove infeed from TPU Strategy when using host training loop.
PiperOrigin-RevId: 230877075
commit 53d3a75b7b
parent bdca83b71c
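
With this change the default TPUStrategy path builds the training loop on the host and preloads input directly onto the TPU devices; the infeed/enqueue path is only constructed when the experimental `_disable_training_loop_on_host` flag is passed (and, per the docstring below, experimental flags are expected to go away). A minimal usage sketch, assuming `TPUStrategy` here is the class modified in this diff and `resolver` is an already-configured TPUClusterResolver (neither is shown in the patch; the steps_per_run value is illustrative):

    # Default: host training loop, input preloaded onto the TPUs, no infeed.
    strategy = TPUStrategy(tpu_cluster_resolver=resolver, steps_per_run=100)

    # Experimental escape hatch kept by this change: device training loop,
    # fed through infeed from the host CPUs.
    device_loop_strategy = TPUStrategy(
        tpu_cluster_resolver=resolver,
        steps_per_run=100,
        _disable_training_loop_on_host=True)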
@@ -21,6 +21,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import collections
 import copy

 from tensorflow.contrib.tpu.python.ops import tpu_ops
@@ -42,6 +43,7 @@ from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver a
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
@@ -98,9 +100,9 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
     *args, **kwargs):
   # Figure out what collections this variable should be added to.
   # We'll add the TPUMirroredVariable to those collections instead.
-  collections = kwargs.pop("collections", None)
-  if collections is None:
-    collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+  var_collections = kwargs.pop("collections", None)
+  if var_collections is None:
+    var_collections = [ops.GraphKeys.GLOBAL_VARIABLES]
   kwargs["collections"] = []

   # TODO(jhseu): Should we have different behavior for different
@@ -139,11 +141,11 @@ def _create_tpu_mirrored_variable(  # pylint: disable=missing-docstring
     # "trainable" to False for next_creator() since that causes functions
     # like implicit_gradients to skip those variables.
     if kwargs.get("trainable", True):
-      collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+      var_collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
       l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
       for v in value_list:
         l.remove(v)
-    g.add_to_collections(collections, result)
+    g.add_to_collections(var_collections, result)
   return result


@@ -170,19 +172,16 @@ class TPUStrategy(distribute_lib.DistributionStrategy):
         the usecase of using a single core within a TPU cluster.
       **kwargs: Additional experimental flags. Will be removed in future.
     """
-    super(TPUStrategy, self).__init__(TPUExtended(
-        self, tpu_cluster_resolver, steps_per_run, device_assignment))
-
-    self._disable_training_loop_on_host = False
     if len(kwargs) > 1:
       raise ValueError("TPUStrategy constructor only takes one experimental "
                        "flag now")
-    if len(kwargs) == 1:
-      if "_disable_training_loop_on_host" not in kwargs:
-        raise ValueError("TPUStrategy constructor does not support arguments: "
-                         "{}".format(kwargs))
-      self._disable_training_loop_on_host = (
-          kwargs["_disable_training_loop_on_host"])
+    elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs:
+      raise ValueError("TPUStrategy constructor does not support arguments: "
+                       "{}".format(kwargs))
+
+    super(TPUStrategy, self).__init__(TPUExtended(
+        self, tpu_cluster_resolver, steps_per_run, device_assignment,
+        kwargs.get("_disable_training_loop_on_host", False)))

   @property
   def steps_per_run(self):
@@ -197,7 +196,8 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
                container_strategy,
                tpu_cluster_resolver=None,
                steps_per_run=None,
-               device_assignment=None):
+               device_assignment=None,
+               disable_training_loop_on_host=False):
     super(TPUExtended, self).__init__(container_strategy)

     if tpu_cluster_resolver is None:
@@ -211,6 +211,7 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._tpu_cluster_resolver = tpu_cluster_resolver
     self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
     self._device_assignment = device_assignment
+    self._disable_training_loop_on_host = disable_training_loop_on_host

     # Device assignment is currently only supported for 1 core case.
     if self._device_assignment:
@@ -238,15 +239,25 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
     self._device_map = values.ReplicaDeviceMap(self._tpu_devices)

-    # For input:
-    input_device_map = values.ReplicaDeviceMap(tuple(
-        self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
-    worker_devices = [
-        (self.get_host(hid), [self.get_host_cpu_device(hid)])
-        for hid in range(self.num_hosts)
-    ]
-    self._input_workers = input_lib.InputWorkers(
-        input_device_map, worker_devices)
+    # If the training loop is on the device, we must use the infeed, with input
+    # on the host. Otherwise, we preload the data onto the TPUs.
+    if disable_training_loop_on_host:
+      input_device_map = values.ReplicaDeviceMap(tuple(
+          self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
+      worker_devices = [
+          (self.get_host(hid), [self.get_host_cpu_device(hid)])
+          for hid in range(self.num_hosts)
+      ]
+      self._input_workers = input_lib.InputWorkers(
+          input_device_map, worker_devices)
+    else:
+      input_worker_devices = collections.OrderedDict()
+      for tpu_device in self._tpu_devices:
+        host_device = _get_host_for_device(tpu_device)
+        input_worker_devices.setdefault(host_device, [])
+        input_worker_devices[host_device].append(tpu_device)
+      self._input_workers = input_lib.InputWorkers(
+          self._device_map, tuple(input_worker_devices.items()))

     # TODO(sourabhbajaj): Remove this once performance of running one step
     # at a time is comparable to multiple steps.
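
The new else-branch above drops the per-host infeed workers and instead groups the TPU devices under the host CPU that feeds them, so each input worker preloads data for its own cores. A small self-contained sketch of that grouping in plain Python, using a simplified stand-in for the `_get_host_for_device` helper added at the end of this diff (device strings are illustrative):

    import collections

    def host_cpu_for(tpu_device):
      # Simplified stand-in for _get_host_for_device: keep the job/replica/task
      # prefix and point at CPU:0 on the same host.
      return tpu_device.rsplit("/device:", 1)[0] + "/device:CPU:0"

    tpu_devices = (
        "/job:worker/replica:0/task:0/device:TPU:0",
        "/job:worker/replica:0/task:0/device:TPU:1",
        "/job:worker/replica:0/task:1/device:TPU:0",
        "/job:worker/replica:0/task:1/device:TPU:1",
    )

    input_worker_devices = collections.OrderedDict()
    for tpu_device in tpu_devices:
      input_worker_devices.setdefault(host_cpu_for(tpu_device), []).append(tpu_device)

    # Each host CPU now maps to the TPU cores it feeds:
    #   .../task:0/device:CPU:0 -> [.../task:0/device:TPU:0, .../task:0/device:TPU:1]
    #   .../task:1/device:CPU:0 -> [.../task:1/device:TPU:0, .../task:1/device:TPU:1]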
@@ -346,6 +357,113 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
   # a mechanism to infer the outputs of `fn`. Pending b/110550782.
   def _experimental_run_steps_on_iterator(
       self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
+    if self._disable_training_loop_on_host:
+      impl = self._run_steps_on_iterator_with_device_loop
+    else:
+      impl = self._run_steps_on_iterator_with_host_loop
+
+    return impl(
+        fn=fn, multi_worker_iterator=multi_worker_iterator,
+        iterations=iterations, initial_loop_values=initial_loop_values)
+
+  def _run_steps_on_iterator_with_host_loop(
+      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
+    output_shapes = multi_worker_iterator.output_shapes
+    shapes = nest.flatten(output_shapes)
+    if any(not s.is_fully_defined() for s in shapes):
+      raise ValueError(
+          "TPU currently requires fully defined shapes. Either use "
+          "set_shape() on the input tensors or use "
+          "dataset.batch(..., drop_remainder=True).")
+
+    # Wrap `fn` for repeat.
+    if initial_loop_values is None:
+      initial_loop_values = {}
+    initial_loop_values = nest.flatten(initial_loop_values)
+    ctx = input_lib.MultiStepContext()
+
+    def run_fn(inputs):
+      """Single step on the TPU device."""
+      fn_result = fn(ctx, inputs)
+      flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
+      if flat_last_step_outputs:
+        with ops.control_dependencies([fn_result]):
+          return [array_ops.identity(f) for f in flat_last_step_outputs]
+      else:
+        return fn_result
+
+    # We capture the control_flow_context at this point, before we run `fn`
+    # inside a while_loop and TPU replicate context. This is useful in cases
+    # where we might need to exit these contexts and get back to the outer
+    # context to do some things, for e.g. create an op which should be
+    # evaluated only once at the end of the loop on the host. One such usage
+    # is in creating metrics' value op.
+    self._outer_control_flow_context = (
+        ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access
+
+    def rewrite_fn(*args):
+      """The rewritten step fn running on TPU."""
+      del args
+
+      per_replica_inputs = multi_worker_iterator.get_next()
+      replicate_inputs = []
+      for replica_id in range(self._num_replicas_in_sync):
+        select_replica = lambda x: values.select_replica(replica_id, x)  # pylint: disable=cell-var-from-loop
+        replicate_inputs.append((nest.map_structure(
+            select_replica, per_replica_inputs),))
+
+      replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
+
+      # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
+      # will flatten it in this case. If run_fn has no tensor outputs,
+      # tpu.replicate returns a list of no_ops, we will keep the output as it
+      # is.
+      if isinstance(replicate_outputs[0], list):
+        replicate_outputs = nest.flatten(replicate_outputs)
+
+      return replicate_outputs
+
+    # TODO(sourabhbajaj): The input to while loop should be based on the
+    # output type of the step_fn
+    assert isinstance(initial_loop_values, list)
+    initial_loop_values = initial_loop_values * self._num_replicas_in_sync
+
+    # Put the while loop op on host 0.
+    with ops.device(self.get_host_cpu_device(0)):
+      replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
+                                               initial_loop_values)
+
+    del self._outer_control_flow_context
+    ctx.run_op = control_flow_ops.group(replicate_outputs)
+
+    if isinstance(replicate_outputs, list):
+      # Filter out any ops from the outputs, typically this would be the case
+      # when there were no tensor outputs.
+      last_step_tensor_outputs = [
+          x for x in replicate_outputs if not isinstance(x, ops.Operation)
+      ]
+
+      # Outputs are currently of the structure (flattened)
+      # [output0_device0, output1_device0, output2_device0,
+      #  output0_device1, output1_device1, output2_device1,
+      #  ...]
+      # Convert this to the following structure instead: (grouped by output)
+      # [[output0_device0, output0_device1],
+      #  [output1_device0, output1_device1],
+      #  [output2_device0, output2_device1]]
+      output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
+      last_step_tensor_outputs = [
+          last_step_tensor_outputs[i::output_num] for i in range(output_num)
+      ]
+    else:
+      # no tensors returned.
+      last_step_tensor_outputs = []
+
+    _set_last_step_outputs(ctx, last_step_tensor_outputs)
+    return ctx
+
+  def _run_steps_on_iterator_with_device_loop(
+      self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
     output_shapes = multi_worker_iterator.output_shapes
     shapes = nest.flatten(output_shapes)
     if any(not s.is_fully_defined() for s in shapes):
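
In the host-loop path added above, the `tpu.replicate` results are flattened replica-major, so the last-step tensors come back as [output0_replica0, output1_replica0, ..., output0_replica1, ...]; the strided slice `last_step_tensor_outputs[i::output_num]` then regroups them by output, as the in-code comment describes. A tiny worked example of that regrouping (plain Python, names illustrative):

    flat = ["out0_rep0", "out1_rep0", "out2_rep0",
            "out0_rep1", "out1_rep1", "out2_rep1"]

    num_replicas_in_sync = 2
    output_num = len(flat) // num_replicas_in_sync  # 3 outputs per replica

    grouped = [flat[i::output_num] for i in range(output_num)]
    assert grouped == [["out0_rep0", "out0_rep1"],
                       ["out1_rep0", "out1_rep1"],
                       ["out2_rep0", "out2_rep1"]]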
@@ -393,95 +511,28 @@ class TPUExtended(distribute_lib.DistributionStrategyExtended):
     self._outer_control_flow_context = (
         ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

-    # pylint: disable=protected-access
-    if self._container_strategy()._disable_training_loop_on_host:
-      replicate_inputs = [[]] * self._num_replicas_in_sync
-      replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
-    else:
-      def rewrite_fn(*args):
-        """The rewritten step fn running on TPU."""
-        del args
-        replicate_inputs = [[]] * self._num_replicas_in_sync
-        replicate_outputs = tpu.replicate(run_fn, replicate_inputs)
-
-        # If run_fn has tensor outputs, tpu.replicate returns a list of list. We
-        # will flatten it in this case. If run_fn has no tensor outputs,
-        # tpu.replicate returns a list of no_ops, we will keep the output as it
-        # is.
-        if isinstance(replicate_outputs[0], list):
-          replicate_outputs = nest.flatten(replicate_outputs)
-
-        return replicate_outputs
-
-      # TODO(sourabhbajaj): The input to while loop should be based on the
-      # output type of the step_fn
-      assert isinstance(initial_loop_values, list)
-      initial_loop_values = initial_loop_values * self._num_replicas_in_sync
-
-      # Put the while loop op on host 0.
-      with ops.device(self.get_host_cpu_device(0)):
-        replicate_outputs = training_loop.repeat(iterations, rewrite_fn,
-                                                 initial_loop_values)
-
+    replicate_inputs = [[]] * self._num_replicas_in_sync
+    replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
     del self._outer_control_flow_context
     ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)

-    if self._container_strategy()._disable_training_loop_on_host:
-      # Filter out any ops from the outputs, typically this would be the case
-      # when there were no tensor outputs.
-      last_step_tensor_outputs = [x for x in replicate_outputs
-                                  if not isinstance(x, ops.Operation)]
+    # Filter out any ops from the outputs, typically this would be the case
+    # when there were no tensor outputs.
+    last_step_tensor_outputs = [x for x in replicate_outputs
+                                if not isinstance(x, ops.Operation)]

     # Outputs are currently of the structure (grouped by device)
     # [[output0_device0, output1_device0, output2_device0],
     #  [output0_device1, output1_device1, output2_device1]]
     # Convert this to the following structure instead: (grouped by output)
     # [[output0_device0, output0_device1],
     #  [output1_device0, output1_device1],
     #  [output2_device0, output2_device1]]
     last_step_tensor_outputs = [list(x) for x in
                                 zip(*last_step_tensor_outputs)]
-    else:
-      if isinstance(replicate_outputs, list):
-        # Filter out any ops from the outputs, typically this would be the case
-        # when there were no tensor outputs.
-        last_step_tensor_outputs = [
-            x for x in replicate_outputs if not isinstance(x, ops.Operation)
-        ]
-
-        # Outputs are currently of the structure (flattened)
-        # [output0_device0, output1_device0, output2_device0,
-        #  output0_device1, output1_device1, output2_device1,
-        #  ...]
-        # Convert this to the following structure instead: (grouped by output)
-        # [[output0_device0, output0_device1],
-        #  [output1_device0, output1_device1],
-        #  [output2_device0, output2_device1]]
-        output_num = len(last_step_tensor_outputs) // self._num_replicas_in_sync
-        last_step_tensor_outputs = [
-            last_step_tensor_outputs[i::output_num] for i in range(output_num)
-        ]
-      else:
-        # no tensors returned.
-        last_step_tensor_outputs = []
-
-      # Convert replicate_outputs to the original dict structure of
-      # last_step_outputs.
-      last_step_tensor_outputs_dict = nest.pack_sequence_as(
-          ctx.last_step_outputs, last_step_tensor_outputs)
-
-      for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
-        output = last_step_tensor_outputs_dict[name]
-        # For outputs that have already been reduced, take the first value
-        # from the list as each value should be the same. Else return the full
-        # list of values.
-        # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
-        # value.
-        if reduce_op is not None:
-          # TODO(priyag): Should this return the element or a list with 1 element
-          last_step_tensor_outputs_dict[name] = output[0]
-      ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access

+    _set_last_step_outputs(ctx, last_step_tensor_outputs)
     return ctx

   def _call_for_each_replica(self, fn, args, kwargs):
@@ -744,3 +795,30 @@ class _TPUReplicaContext(distribute_lib.ReplicaContext):
     ds = self._strategy
     replica_id = tensor_util.constant_value(self._replica_id_in_sync_group)
     return (ds.extended.worker_devices[replica_id],)
+
+
+def _get_host_for_device(device):
+  spec = tf_device.DeviceSpec.from_string(device)
+  return tf_device.DeviceSpec(
+      job=spec.job, replica=spec.replica, task=spec.task,
+      device_type="CPU", device_index=0).to_string()
+
+
+def _set_last_step_outputs(ctx, last_step_tensor_outputs):
+  """Sets the last step outputs on the given context."""
+  # Convert replicate_outputs to the original dict structure of
+  # last_step_outputs.
+  last_step_tensor_outputs_dict = nest.pack_sequence_as(
+      ctx.last_step_outputs, last_step_tensor_outputs)
+
+  for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
+    output = last_step_tensor_outputs_dict[name]
+    # For outputs that have already been reduced, take the first value
+    # from the list as each value should be the same. Else return the full
+    # list of values.
+    # TODO(josh11b): If reduce_op is NONE, we should return a PerReplica
+    # value.
+    if reduce_op is not None:
+      # TODO(priyag): Should this return the element or a list with 1 element
+      last_step_tensor_outputs_dict[name] = output[0]
+  ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
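
For reference, the new `_get_host_for_device` helper rewrites a TPU device spec into the CPU:0 device of the same host, which is what the input-worker grouping earlier in this diff relies on. A short usage sketch (the commented result is what `tf_device.DeviceSpec.to_string()` should produce for these fields):

    host = _get_host_for_device("/job:worker/replica:0/task:1/device:TPU:3")
    # Expected: "/job:worker/replica:0/task:1/device:CPU:0"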