Sort the keras predict result based on the batch index in v2 code path.
PiperOrigin-RevId: 279158153 Change-Id: I76638fa6613cb0de1f8f6494c960debbee3055ce
This commit is contained in:
parent
37cdb806bc
commit
8e0f8c4d8c
@ -1612,6 +1612,23 @@ tf_py_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "training_v2_utils_test",
|
||||||
|
size = "medium",
|
||||||
|
srcs = ["engine/training_v2_utils_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":keras",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python/distribute:strategy_combinations",
|
||||||
|
],
|
||||||
|
tags = [
|
||||||
|
"no_oss", # TODO(b/135021748) reenable
|
||||||
|
"notsan",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "model_subclassing_test_util",
|
name = "model_subclassing_test_util",
|
||||||
srcs = ["model_subclassing_test_util.py"],
|
srcs = ["model_subclassing_test_util.py"],
|
||||||
|
@ -164,7 +164,8 @@ def run_one_epoch(model,
|
|||||||
batch_logs['size'] = data_batch_size
|
batch_logs['size'] = data_batch_size
|
||||||
current_batch_size = data_batch_size
|
current_batch_size = data_batch_size
|
||||||
else:
|
else:
|
||||||
batch_outs = _aggregate_predict_results(strategy, batch_outs, model)
|
batch_outs = training_v2_utils._aggregate_predict_results(
|
||||||
|
strategy, batch_outs, model)
|
||||||
|
|
||||||
if step == 0:
|
if step == 0:
|
||||||
aggregator.create(batch_outs)
|
aggregator.create(batch_outs)
|
||||||
@ -435,6 +436,8 @@ class Loop(training_utils.TrainingLoop):
|
|||||||
|
|
||||||
# tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
|
# tf.print('{} on {} steps.'.format(ModeKeys.TRAIN, steps_per_epoch))
|
||||||
training_context = TrainingContext()
|
training_context = TrainingContext()
|
||||||
|
if mode == ModeKeys.PREDICT:
|
||||||
|
dataset = training_v2_utils._add_batch_index_to_element(dataset)
|
||||||
dataset = strategy.experimental_distribute_dataset(dataset)
|
dataset = strategy.experimental_distribute_dataset(dataset)
|
||||||
|
|
||||||
execution_function = training_v2_utils._get_or_make_execution_function(
|
execution_function = training_v2_utils._get_or_make_execution_function(
|
||||||
@ -708,18 +711,6 @@ def _get_total_number_of_samples(adapter):
|
|||||||
return total_sample
|
return total_sample
|
||||||
|
|
||||||
|
|
||||||
def _aggregate_predict_results(strategy, batch_outs, model):
|
|
||||||
if not isinstance(batch_outs, list):
|
|
||||||
batch_outs = [batch_outs]
|
|
||||||
total_batch_outs = []
|
|
||||||
for i in range(len(model.outputs)):
|
|
||||||
num_replicas = strategy.num_replicas_in_sync
|
|
||||||
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
|
|
||||||
total_batch_outs.append(
|
|
||||||
dist_utils.concat_along_batch_dimension(nest.flatten(nested_outs)))
|
|
||||||
return total_batch_outs
|
|
||||||
|
|
||||||
|
|
||||||
def _print_train_info(total_samples, steps, val_total_samples, val_steps):
|
def _print_train_info(total_samples, steps, val_total_samples, val_steps):
|
||||||
increment = 'samples' if total_samples else 'steps'
|
increment = 'samples' if total_samples else 'steps'
|
||||||
conjunction = 'on' if total_samples else 'for'
|
conjunction = 'on' if total_samples else 'for'
|
||||||
|
@ -26,9 +26,13 @@ from __future__ import print_function
|
|||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.framework.ops import composite_tensor
|
from tensorflow.python.framework.ops import composite_tensor
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
@ -38,6 +42,7 @@ from tensorflow.python.keras.engine import training_utils
|
|||||||
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
@ -62,14 +67,14 @@ def _make_execution_function(model, mode):
|
|||||||
|
|
||||||
def distributed_function(input_iterator):
|
def distributed_function(input_iterator):
|
||||||
"""A single step of the distributed execution across replicas."""
|
"""A single step of the distributed execution across replicas."""
|
||||||
x, y, sample_weights = _prepare_feed_values(model, input_iterator, mode)
|
args = _prepare_feed_values(model, input_iterator, mode)
|
||||||
# Call `Model.{train,test,predict}_on_batch` on every replica passing
|
# Call `Model.{train,test,predict}_on_batch` on every replica passing
|
||||||
# PerReplicas as arguments. On every replica inside this call, each
|
# PerReplicas as arguments. On every replica inside this call, each
|
||||||
# PerReplica object will return the value for that replica. The outputs
|
# PerReplica object will return the value for that replica. The outputs
|
||||||
# are PerReplicas too.
|
# are PerReplicas too.
|
||||||
strategy = distribution_strategy_context.get_strategy()
|
strategy = distribution_strategy_context.get_strategy()
|
||||||
outputs = strategy.experimental_run_v2(
|
outputs = strategy.experimental_run_v2(
|
||||||
per_replica_function, args=(x, y, sample_weights))
|
per_replica_function, args=args)
|
||||||
# Out of PerReplica outputs reduce or pick values to return.
|
# Out of PerReplica outputs reduce or pick values to return.
|
||||||
all_outputs = dist_utils.unwrap_output_dict(
|
all_outputs = dist_utils.unwrap_output_dict(
|
||||||
strategy, outputs, mode)
|
strategy, outputs, mode)
|
||||||
@ -108,7 +113,11 @@ def _prepare_feed_values(model, inputs, mode):
|
|||||||
(tuple, targets, sample_weights) may be a python list. Single values
|
(tuple, targets, sample_weights) may be a python list. Single values
|
||||||
for inputs will always be wrapped in lists.
|
for inputs will always be wrapped in lists.
|
||||||
"""
|
"""
|
||||||
inputs, targets, sample_weights = _get_input_from_iterator(inputs)
|
# For predict, we need to extract the manually added batch_index first.
|
||||||
|
with_batch_index = mode == ModeKeys.PREDICT
|
||||||
|
|
||||||
|
inputs, targets, sample_weights, batch_index = _get_input_from_iterator(
|
||||||
|
inputs, with_batch_index)
|
||||||
|
|
||||||
# When the inputs are dict, then we want to flatten it in the same order as
|
# When the inputs are dict, then we want to flatten it in the same order as
|
||||||
# the input layers, such that the data are fed into the input layers in the
|
# the input layers, such that the data are fed into the input layers in the
|
||||||
@ -123,12 +132,18 @@ def _prepare_feed_values(model, inputs, mode):
|
|||||||
targets = []
|
targets = []
|
||||||
|
|
||||||
ins = [inputs, targets, sample_weights]
|
ins = [inputs, targets, sample_weights]
|
||||||
|
if batch_index is not None:
|
||||||
|
ins.append(batch_index)
|
||||||
return tuple(ins)
|
return tuple(ins)
|
||||||
|
|
||||||
|
|
||||||
def _get_input_from_iterator(iterator):
|
def _get_input_from_iterator(iterator, with_batch_index=False):
|
||||||
"""Get elements from the iterator and verify the input shape and type."""
|
"""Get elements from the iterator and verify the input shape and type."""
|
||||||
next_element = next(iterator)
|
next_element = next(iterator)
|
||||||
|
if with_batch_index:
|
||||||
|
batch_index, next_element = next_element
|
||||||
|
else:
|
||||||
|
batch_index = None
|
||||||
|
|
||||||
if (tensor_util.is_tensor(next_element) or
|
if (tensor_util.is_tensor(next_element) or
|
||||||
isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
|
isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
|
||||||
@ -146,7 +161,7 @@ def _get_input_from_iterator(iterator):
|
|||||||
# Validate that all the elements in x and y are of the same type and shape.
|
# Validate that all the elements in x and y are of the same type and shape.
|
||||||
dist_utils.validate_distributed_dataset_inputs(
|
dist_utils.validate_distributed_dataset_inputs(
|
||||||
distribution_strategy_context.get_strategy(), x, y, sample_weights)
|
distribution_strategy_context.get_strategy(), x, y, sample_weights)
|
||||||
return x, y, sample_weights
|
return x, y, sample_weights, batch_index
|
||||||
|
|
||||||
|
|
||||||
def _make_replica_execution_function(model, mode):
|
def _make_replica_execution_function(model, mode):
|
||||||
@ -156,9 +171,11 @@ def _make_replica_execution_function(model, mode):
|
|||||||
elif mode == ModeKeys.TEST:
|
elif mode == ModeKeys.TEST:
|
||||||
func = functools.partial(test_on_batch, model)
|
func = functools.partial(test_on_batch, model)
|
||||||
else:
|
else:
|
||||||
def _predict_on_batch(x, y=None, sample_weights=None):
|
def _predict_on_batch(x, y=None, sample_weights=None, batch_index=None):
|
||||||
del y, sample_weights
|
del y, sample_weights
|
||||||
return predict_on_batch(model, x)
|
# Note that the x and batch_index is already per-replica value.
|
||||||
|
result = predict_on_batch(model, x)
|
||||||
|
return (batch_index, result)
|
||||||
|
|
||||||
func = _predict_on_batch
|
func = _predict_on_batch
|
||||||
|
|
||||||
@ -170,6 +187,105 @@ def _make_replica_execution_function(model, mode):
|
|||||||
return func
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def _aggregate_predict_results(strategy, batch_outs, model):
|
||||||
|
"""Aggregate the prediction result from each replica."""
|
||||||
|
num_replicas = strategy.num_replicas_in_sync
|
||||||
|
num_outputs = len(model.outputs)
|
||||||
|
|
||||||
|
if not isinstance(batch_outs, list):
|
||||||
|
batch_outs = [batch_outs]
|
||||||
|
|
||||||
|
# batch_outs is in following structure:
|
||||||
|
# [
|
||||||
|
# replica_1_batch_index, replica_2_batch_index, ...., replica_x_batch_index,
|
||||||
|
# replica_1_output_1, replica_2_output_1, ...., replica_x_output_1,
|
||||||
|
# ......
|
||||||
|
# replica_1_output_y, replica_2_output_y, ...., replica_x_output_y,
|
||||||
|
# ]
|
||||||
|
batch_index, batch_outs = batch_outs[:num_replicas], batch_outs[num_replicas:]
|
||||||
|
batch_index = dist_utils.concat_along_batch_dimension(batch_index)
|
||||||
|
# Reorder the batch_index for it to do proper gather. Eg, if the original
|
||||||
|
# index is [0, 2, 4, 6, 1, 3, 5, 7], then the index for gather should be
|
||||||
|
# [0, 4, 1, 5, 2, 6, 3, 7].
|
||||||
|
batch_index = np.argsort(batch_index)
|
||||||
|
# Only need to gather if the batch index is not sorted.
|
||||||
|
need_batch_index_gather = np.any(np.diff(batch_index) < 0)
|
||||||
|
|
||||||
|
total_batch_outs = []
|
||||||
|
for i in range(num_outputs):
|
||||||
|
nested_outs = batch_outs[i * num_replicas:i * num_replicas + num_replicas]
|
||||||
|
per_output_result = dist_utils.concat_along_batch_dimension(
|
||||||
|
nest.flatten(nested_outs))
|
||||||
|
|
||||||
|
if need_batch_index_gather:
|
||||||
|
if _get_batch_size(per_output_result).numpy() == len(batch_index):
|
||||||
|
# Skip the gather if the output has a different batch size than the
|
||||||
|
# batch_index. There will be some error handling in upper layer.
|
||||||
|
per_output_result = _gather_result_by_index(per_output_result,
|
||||||
|
batch_index)
|
||||||
|
total_batch_outs.append(per_output_result)
|
||||||
|
return total_batch_outs
|
||||||
|
|
||||||
|
|
||||||
|
def _gather_result_by_index(input_tensor, batch_index):
|
||||||
|
"""Handle the data element gather for different type of tensor."""
|
||||||
|
if isinstance(input_tensor, sparse_tensor.SparseTensor):
|
||||||
|
# For sparse tensor, both the index and value component should be gathered.
|
||||||
|
return sparse_tensor.SparseTensor(
|
||||||
|
indices=array_ops.gather_v2(input_tensor.indices, batch_index),
|
||||||
|
values=array_ops.gather_v2(input_tensor.values, batch_index),
|
||||||
|
dense_shape=input_tensor.dense_shape
|
||||||
|
)
|
||||||
|
# For both ragged tensor or eager tensor or np array, tf.gather should do the
|
||||||
|
# correct thing.
|
||||||
|
elif isinstance(input_tensor, ragged_tensor.RaggedTensor):
|
||||||
|
return array_ops.gather_v2(input_tensor, batch_index)
|
||||||
|
elif isinstance(input_tensor, (ops.EagerTensor, np.ndarray)):
|
||||||
|
return array_ops.gather_v2(input_tensor, batch_index).numpy()
|
||||||
|
else:
|
||||||
|
raise ValueError('Unexpected type {} encountered when gathering '
|
||||||
|
'batch slices.'.format(input_tensor))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_batch_size(inputs):
|
||||||
|
first_inputs = nest.flatten(inputs)[0]
|
||||||
|
if isinstance(first_inputs, ragged_tensor.RaggedTensor):
|
||||||
|
return first_inputs.bounding_shape()[0]
|
||||||
|
else:
|
||||||
|
return array_ops.shape(first_inputs)[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _add_batch_index_to_element(dataset):
|
||||||
|
"""Adding a new batch index field to the every element in the batch.
|
||||||
|
|
||||||
|
This is need in the model.predict() when running with multi-worker
|
||||||
|
distribution strategy. When sharding/distributing a dataset, the continuity of
|
||||||
|
the sharded dataset can't be easily ensured without performance sacrifice. It
|
||||||
|
is fine to train and eval with the reordered data, but not for prediction. To
|
||||||
|
solve this issue, Keras will add a batch index to each of the element in the
|
||||||
|
dataset, which will then pass to pre-replica execution function. The real
|
||||||
|
execution function will remove it before feeding the input to the model, and
|
||||||
|
pre-replica function will then zip the index with the result. Finally Keras
|
||||||
|
will sort the batch result based on the added batch-index field, remove it and
|
||||||
|
return the sorted result.
|
||||||
|
|
||||||
|
Note that we didn't add single index to the per-replica batch, but to each of
|
||||||
|
the element in the batch, since we can't ensure the data in pre-replica is
|
||||||
|
continuous. Eg: model with 2 replica and predict with 4 elements per batch
|
||||||
|
like [1, 2, 3, 4], it is possible to shard as [1, 2], [3, 4],
|
||||||
|
or [1, 3], [2, 4].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: a dataset that is created by any of the data_adapter, with the
|
||||||
|
element structure as (x, y, sample_weights).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
a new dataset, with the element shape as
|
||||||
|
(batch_index, (x, y, sample_weights)).
|
||||||
|
"""
|
||||||
|
return dataset.map(lambda *inp: (math_ops.range(_get_batch_size(inp)), inp))
|
||||||
|
|
||||||
|
|
||||||
def train_on_batch(
|
def train_on_batch(
|
||||||
model,
|
model,
|
||||||
x,
|
x,
|
||||||
|
143
tensorflow/python/keras/engine/training_v2_utils_test.py
Normal file
143
tensorflow/python/keras/engine/training_v2_utils_test.py
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tensorflow.python.keras.engine.training_v2_utils."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.distribute import mirrored_strategy
|
||||||
|
from tensorflow.python.distribute import strategy_combinations
|
||||||
|
from tensorflow.python.eager import def_function
|
||||||
|
from tensorflow.python.framework import combinations
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
|
||||||
|
from tensorflow.python.keras.engine import training_v2_utils
|
||||||
|
from tensorflow.python.keras.utils.mode_keys import ModeKeys
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class AggregatePredictResultsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(AggregatePredictResultsTest, self).setUp()
|
||||||
|
strategy_combinations.set_virtual_cpus_to_at_least(3)
|
||||||
|
self.num_replica = 3
|
||||||
|
self.batch_size = 16
|
||||||
|
self.dense_shape = (2, 3)
|
||||||
|
self.total_sample = 2 * self.batch_size
|
||||||
|
|
||||||
|
mock_model = collections.namedtuple('Model', ['outputs'])
|
||||||
|
self.mock_model = mock_model([1])
|
||||||
|
|
||||||
|
strategy = mirrored_strategy.MirroredStrategy(
|
||||||
|
['/cpu:0', '/cpu:1', '/cpu:2'])
|
||||||
|
|
||||||
|
execution_function = lambda *inp: inp
|
||||||
|
@def_function.function
|
||||||
|
def predict_loop(batch):
|
||||||
|
batch_result = strategy.experimental_run_v2(execution_function, batch)
|
||||||
|
batch_result = dist_utils.unwrap_output_dict(
|
||||||
|
strategy, batch_result, ModeKeys.PREDICT)
|
||||||
|
# swap the order of replica 1 and 2, to mimic random order.
|
||||||
|
batch_result[2], batch_result[1] = batch_result[1], batch_result[2]
|
||||||
|
batch_result[5], batch_result[4] = batch_result[4], batch_result[5]
|
||||||
|
return batch_result
|
||||||
|
|
||||||
|
self.strategy = strategy
|
||||||
|
self.predict_loop = predict_loop
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
|
||||||
|
mode='eager'))
|
||||||
|
def test_aggregate_predict_results_dense(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(self.total_sample)
|
||||||
|
def dense_map_fn(i):
|
||||||
|
# Mimic what we do for adding batch index
|
||||||
|
return i, array_ops.fill(self.dense_shape, i)
|
||||||
|
dense_dataset = dataset.map(dense_map_fn).batch(self.batch_size)
|
||||||
|
distributed_data = self.strategy.experimental_distribute_dataset(
|
||||||
|
dense_dataset)
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for batch in distributed_data:
|
||||||
|
batch_result = self.predict_loop(batch)
|
||||||
|
final_result = training_v2_utils._aggregate_predict_results(
|
||||||
|
self.strategy, batch_result, self.mock_model)
|
||||||
|
|
||||||
|
# Make sure the dense result is in a sorted order.
|
||||||
|
expected_result = np.arange(
|
||||||
|
start=start, stop=start+self.batch_size).reshape((-1, 1))
|
||||||
|
expected_result = np.tile(expected_result, 6).reshape(
|
||||||
|
(-1,) + self.dense_shape)
|
||||||
|
self.assertAllClose(final_result[0], expected_result)
|
||||||
|
start += self.batch_size
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
|
||||||
|
mode='eager'))
|
||||||
|
def test_aggregate_predict_results_sparse(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(self.total_sample)
|
||||||
|
def sparse_map_fn(i):
|
||||||
|
return i, sparse_tensor.SparseTensor(
|
||||||
|
indices=[(0, 0)],
|
||||||
|
values=[i],
|
||||||
|
dense_shape=self.dense_shape)
|
||||||
|
sparse_dataset = dataset.map(sparse_map_fn).batch(self.batch_size)
|
||||||
|
distributed_data = self.strategy.experimental_distribute_dataset(
|
||||||
|
sparse_dataset)
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for batch in distributed_data:
|
||||||
|
batch_result = self.predict_loop(batch)
|
||||||
|
final_result = training_v2_utils._aggregate_predict_results(
|
||||||
|
self.strategy, batch_result, self.mock_model)
|
||||||
|
|
||||||
|
# Make sure the dense result is in a sorted order.
|
||||||
|
expected_values = np.arange(start=start, stop=start+self.batch_size)
|
||||||
|
self.assertAllClose(final_result[0].values, expected_values)
|
||||||
|
start += self.batch_size
|
||||||
|
|
||||||
|
@combinations.generate(combinations.combine(tf_api_version=[1, 2],
|
||||||
|
mode='eager'))
|
||||||
|
def test_aggregate_predict_results_ragged(self):
|
||||||
|
dataset = dataset_ops.Dataset.range(self.total_sample)
|
||||||
|
def ragged_map_fn(i):
|
||||||
|
return i, ragged_factory_ops.constant([[0], [], []], dtype=np.int64) + i
|
||||||
|
ragged_dataset = dataset.map(ragged_map_fn).batch(self.batch_size)
|
||||||
|
distributed_data = self.strategy.experimental_distribute_dataset(
|
||||||
|
ragged_dataset)
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
for batch in distributed_data:
|
||||||
|
batch_result = self.predict_loop(batch)
|
||||||
|
final_result = training_v2_utils._aggregate_predict_results(
|
||||||
|
self.strategy, batch_result, self.mock_model)
|
||||||
|
|
||||||
|
# Make sure the dense result is in a sorted order.
|
||||||
|
expected_values = np.arange(start=start, stop=start+self.batch_size)
|
||||||
|
self.assertAllClose(final_result[0].flat_values, expected_values)
|
||||||
|
start += self.batch_size
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
Loading…
x
Reference in New Issue
Block a user