diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index fc7b8461706..2e6e1190488 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -559,17 +559,6 @@ class FuncGraph(ops.Graph):
     # makes sure that any tensor needed by a custom_gradient is correctly
     # captured.
 
-    # TODO(b/134097853): figure out a better way to check distributed variables
-    if hasattr(tensor, "_distribute_strategy") and hasattr(tensor, "_values"):
-      # This checks if the 'tensor' is a DistributedVariable. When it is a
-      # DistributedVariable, we do not want to check its "graph" attr as the
-      # following if branch does, because "graph" is not an attr for the
-      # container DistributedVariable object, and the underlying components may
-      # not have been initialized yet.
-      # The reason we do not use isinstance() is due to cyclic dependency issue.
-      if name is None:
-        name = str("distributed_variable")
-      return self._capture_helper(tensor, name)
     if (getattr(tensor, "graph", None) is not self and
         hasattr(self, "_forward_func_graph") and
         isinstance(self._forward_func_graph, FuncGraph)):
@@ -605,6 +594,12 @@ class FuncGraph(ops.Graph):
                           lambda x: [x])
     return captured_tensor
 
+  def capture_distributed_variable(self, variable, placeholder):
+    """Add given distributed variable to captures with given placeholder."""
+    self.captures[variable] = placeholder
+    tape.record_operation("captured_value", [placeholder], [variable],
+                          lambda x: [x])
+
   @property
   def external_captures(self):
     """External tensors captured by this function."""
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index bf62c6b8530..f2994472aa1 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -176,22 +176,26 @@ class Loader(object):
       if bound_inputs:
         for bound_input, internal_capture in zip(
             bound_inputs, concrete_function.inputs[-len(bound_inputs):]):
-          concrete_function.graph.captures[bound_input] = internal_capture
-          if internal_capture.dtype == dtypes.resource:
-            if resource_variable_ops.is_resource_variable(bound_input):
-              try:
-                handle = bound_input.handle
-              except ValueError:
-                # For mirrored variables we'll copy handle data for components
-                # as they get captured.
-                pass
+          if ds_values.is_distributed_variable(bound_input):
+            concrete_function.graph.capture_distributed_variable(
+                bound_input, internal_capture)
+          else:
+            concrete_function.graph.captures[bound_input] = internal_capture
+            if internal_capture.dtype == dtypes.resource:
+              if resource_variable_ops.is_resource_variable(bound_input):
+                try:
+                  handle = bound_input.handle
+                except ValueError:
+                  # For mirrored variables we'll copy handle data for components
+                  # as they get captured.
+                  pass
+                else:
+                  custom_gradient.copy_handle_data(handle, internal_capture)
              else:
-                custom_gradient.copy_handle_data(handle, internal_capture)
-            else:
-              custom_gradient.copy_handle_data(bound_input, internal_capture)
-          # Setting "captures" first means "capture" won't create a new
-          # placeholder for this input.
-          concrete_function.graph.capture(bound_input)
+                custom_gradient.copy_handle_data(bound_input, internal_capture)
+            # Setting "captures" first means "capture" won't create a new
+            # placeholder for this input.
+            concrete_function.graph.capture(bound_input)
 
   def _get_tensor_from_node(self, node_id):
     """Resolves a node id into a tensor to be captured for a function."""
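
For context, a minimal sketch (not part of the patch) of the scenario the loader change is meant to handle: a small model is built and restored under a tf.distribute strategy, so the restored variables are DistributedVariables and Loader registers them through the new FuncGraph.capture_distributed_variable path rather than the resource-placeholder path. The toy model and the export directory "/tmp/dist_var_model" are illustrative placeholders, and the snippet assumes a TensorFlow build that includes this patch.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Build a trivial model under the strategy so its variables are
# MirroredVariables (container DistributedVariables) rather than plain
# ResourceVariables.
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

# Saving traces concrete functions whose bound inputs reference the
# model's variables.
tf.saved_model.save(model, "/tmp/dist_var_model")

# When the load also happens under a strategy scope, the bound inputs can be
# distributed variables; with this patch the loader detects them via
# ds_values.is_distributed_variable and wires them up with
# capture_distributed_variable instead of copying resource handle data.
with strategy.scope():
  restored = tf.saved_model.load("/tmp/dist_var_model")

print(restored.signatures)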