class. This removes the _as_variant_tensor() method from the DatasetV2 class (the version that will be used in TF 2.0) and replaces it with a _variant_tensor property that returns the variant tensor representing the dataset. The __init__() method of DatasetV2 now also takes a variant_tensor argument. For the DatasetV1 class (the current API), we call _as_variant_tensor() from within __init__(), so classes that subclass DatasetV1 should make their super() call at the end of their own __init__().

Another implication is for Estimator code. Estimator input_fns are supposed to be self-contained and cannot contain ops from other graphs (such as the default graph). Previously this was not an issue, because creating a Dataset object did not add anything to the graph. Now it does, so the dataset creation code needs to move into the input_fns.

A few other changes were required to make this happen:

1. The make_one_shot_iterator code captures inputs by value. Since the inputs to a dataset can now be other datasets, which are not capturable, we use the whitelisting mechanism in functions to recreate these ops.
2. The distribution strategies multi-worker code relied on dataset kernels being re-created on different devices when the iterator was created. Now that the kernels are already created, we have to "clone" the dataset onto the different devices.
3. Auto-sharding in distribution strategies is broken by this CL. For now, this CL disables it; it can be fixed subsequently using some of the cloning logic from 2).
4. AsGraphDefInternal for functions that capture dataset inputs now needs to handle those inputs differently, because DT_VARIANT tensors representing datasets are not serializable.

PiperOrigin-RevId: 226115500
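The `__init__` ordering constraint for DatasetV1 subclasses can be illustrated with a minimal pure-Python sketch. The classes below (`DatasetV2Sketch`, `DatasetV1Sketch`, `RangeDatasetV1`, `BrokenRangeDatasetV1`) are hypothetical stand-ins, not TensorFlow classes, and a plain string stands in for the DT_VARIANT tensor. The sketch assumes only what the description above states: the V2 base receives the variant tensor in `__init__`, while the V1 base obtains it by calling `_as_variant_tensor()` from its own `__init__`.

```python
class DatasetV2Sketch(object):
  """Hypothetical stand-in for DatasetV2: the variant tensor is passed in."""

  def __init__(self, variant_tensor):
    self._variant_tensor_attr = variant_tensor

  @property
  def _variant_tensor(self):
    return self._variant_tensor_attr


class DatasetV1Sketch(DatasetV2Sketch):
  """Hypothetical stand-in for DatasetV1: builds the variant in __init__."""

  def __init__(self):
    # Calls the subclass's _as_variant_tensor(), so any state it reads must
    # already be set on self before this constructor runs.
    super(DatasetV1Sketch, self).__init__(self._as_variant_tensor())

  def _as_variant_tensor(self):
    raise NotImplementedError("subclasses must implement _as_variant_tensor")


class RangeDatasetV1(DatasetV1Sketch):
  """Correct ordering: set state first, call super() last."""

  def __init__(self, n):
    self._n = n
    super(RangeDatasetV1, self).__init__()

  def _as_variant_tensor(self):
    # A plain string stands in for a real DT_VARIANT tensor.
    return "RangeDataset(%d)" % self._n


class BrokenRangeDatasetV1(DatasetV1Sketch):
  """Wrong ordering: super() runs _as_variant_tensor() before self._n exists."""

  def __init__(self, n):
    super(BrokenRangeDatasetV1, self).__init__()
    self._n = n

  def _as_variant_tensor(self):
    return "RangeDataset(%d)" % self._n


print(RangeDatasetV1(5)._variant_tensor)  # prints RangeDataset(5)
try:
  BrokenRangeDatasetV1(5)
except AttributeError as e:
  # The base __init__ tried to read self._n before the subclass had set it.
  print("AttributeError:", e)
```

Moving the super() call to the end is what lets the base constructor see the subclass's fields; the broken variant fails at construction time rather than later.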
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to traverse the Dataset dependency structure."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import queue as Queue  # pylint: disable=redefined-builtin

from tensorflow.python.framework import dtypes


def obtain_all_variant_tensor_ops(dataset):
  """Given an input dataset, finds all dataset ops used for construction.

  This dataset was created by a series of transformations, each of which added
  zero or more Dataset ops that each produce a dataset variant tensor. This
  function returns all of those ops.

  Args:
    dataset: Dataset whose variant-tensor-producing ops should be found.

  Returns:
    A list of the variant-tensor-producing dataset ops used to construct this
    dataset.
  """
  all_variant_tensor_ops = []
  bfs_q = Queue.Queue()
  bfs_q.put(dataset._variant_tensor.op)  # pylint: disable=protected-access
  visited = []
  while not bfs_q.empty():
    op = bfs_q.get()
    visited.append(op)
    # We look for all ops that produce variant tensors as output. This is a bit
    # of overkill but the other dataset _inputs() traversal strategies can't
    # cover the case of function inputs that capture dataset variants.
    # TODO(b/120873778): Make this more efficient.
    if op.outputs[0].dtype == dtypes.variant:
      all_variant_tensor_ops.append(op)
    for i in op.inputs:
      input_op = i.op
      if input_op not in visited:
        bfs_q.put(input_op)
  return all_variant_tensor_ops
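To see what `obtain_all_variant_tensor_ops` does without building a TensorFlow graph, here is a minimal sketch of the same breadth-first walk over hypothetical stand-in classes. `FakeOp`, `FakeTensor`, `find_variant_ops`, and the `VARIANT` string (standing in for `dtypes.variant`) are illustrative assumptions, not TensorFlow APIs; they model only the `op.inputs`, `op.outputs`, `tensor.op`, and `tensor.dtype` attributes that the function reads. The toy graph is a `ZipDataset` whose two inputs share a `RangeDataset`, plus a `MapDataset` whose function captures a separate dataset variant as an input, the case the function's comment says the `_inputs()` traversal cannot cover.

```python
from collections import deque

VARIANT = "variant"  # hypothetical stand-in for dtypes.variant
FLOAT32 = "float32"  # any non-variant dtype


class FakeTensor(object):
  """Hypothetical stand-in for a graph tensor: its producing op and dtype."""

  def __init__(self, op, dtype):
    self.op = op
    self.dtype = dtype


class FakeOp(object):
  """Hypothetical stand-in for a graph op with a single output."""

  def __init__(self, name, out_dtype, inputs=()):
    self.name = name
    self.inputs = list(inputs)
    self.outputs = [FakeTensor(self, out_dtype)]


def find_variant_ops(root_op):
  """Breadth-first walk over op.inputs, keeping ops with a variant output."""
  found = []
  queue = deque([root_op])
  seen = {root_op}  # marked when enqueued, so each op is queued at most once
  while queue:
    op = queue.popleft()
    if op.outputs[0].dtype == VARIANT:
      found.append(op)
    for tensor in op.inputs:
      if tensor.op not in seen:
        seen.add(tensor.op)
        queue.append(tensor.op)
  return found


# Toy pipeline: ZipDataset(MapDataset(RangeDataset), RangeDataset), where the
# map function also captures a separate TensorDataset variant as an input.
count = FakeOp("Const", FLOAT32)
range_op = FakeOp("RangeDataset", VARIANT, [count.outputs[0]])
captured = FakeOp("TensorDataset", VARIANT)
map_op = FakeOp("MapDataset", VARIANT,
                [range_op.outputs[0], captured.outputs[0]])
zip_op = FakeOp("ZipDataset", VARIANT,
                [map_op.outputs[0], range_op.outputs[0]])

print([op.name for op in find_variant_ops(zip_op)])
# prints ['ZipDataset', 'MapDataset', 'RangeDataset', 'TensorDataset']
```

The sketch marks an op as seen when it is enqueued rather than when it is dequeued, and keeps the seen ops in a set, so `RangeDataset`, which feeds both `MapDataset` and `ZipDataset`, is queued and reported exactly once, and each membership check is constant time. The non-variant `Const` op is traversed but not reported.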