Support CompositeTensors in V2 single code path.
PiperOrigin-RevId: 259812771
This commit is contained in:
parent
8bac1116b7
commit
0f08941cfb
@ -304,7 +304,7 @@ def validate_per_replica_inputs(distribution_strategy, x):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
# Convert the inputs and targets into a list of PerReplica objects.
|
# Convert the inputs and targets into a list of PerReplica objects.
|
||||||
per_replica_list = nest.flatten(x)
|
per_replica_list = nest.flatten(x, expand_composites=True)
|
||||||
x_values_list = []
|
x_values_list = []
|
||||||
for x in per_replica_list:
|
for x in per_replica_list:
|
||||||
if not tensor_util.is_tensor(x):
|
if not tensor_util.is_tensor(x):
|
||||||
|
@ -27,6 +27,7 @@ import six
|
|||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework.ops import composite_tensor
|
||||||
from tensorflow.python.keras.engine import training_utils
|
from tensorflow.python.keras.engine import training_utils
|
||||||
from tensorflow.python.keras.utils import data_utils
|
from tensorflow.python.keras.utils import data_utils
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
@ -170,7 +171,16 @@ class TensorLikeDataAdapter(DataAdapter):
|
|||||||
if y is not None:
|
if y is not None:
|
||||||
flat_inputs += nest.flatten(y)
|
flat_inputs += nest.flatten(y)
|
||||||
|
|
||||||
return all(isinstance(v, (ops.Tensor, np.ndarray)) for v in flat_inputs)
|
def _is_tensor_or_composite(v):
|
||||||
|
if isinstance(v, (ops.Tensor, np.ndarray)):
|
||||||
|
return True
|
||||||
|
# Dataset inherits from CompositeTensor but shouldn't be handled here.
|
||||||
|
if (isinstance(v, composite_tensor.CompositeTensor) and
|
||||||
|
not isinstance(v, dataset_ops.DatasetV2)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return all(_is_tensor_or_composite(v) for v in flat_inputs)
|
||||||
|
|
||||||
def __init__(self, x, y=None, sample_weights=None, batch_size=None,
|
def __init__(self, x, y=None, sample_weights=None, batch_size=None,
|
||||||
shuffle=False, **kwargs):
|
shuffle=False, **kwargs):
|
||||||
|
@ -283,10 +283,9 @@ def train_on_batch(model,
|
|||||||
targets = training_utils.cast_if_floating_dtype(targets)
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
else:
|
else:
|
||||||
inputs = training_utils.cast_if_floating_to_model_input_dtypes(
|
inputs = training_utils.cast_if_floating_to_model_input_dtypes(
|
||||||
[ops.convert_to_tensor(val) for val in inputs], model)
|
inputs, model)
|
||||||
if targets:
|
if targets:
|
||||||
targets = training_utils.cast_if_floating_dtype(
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
[ops.convert_to_tensor(val) for val in targets])
|
|
||||||
if sample_weights:
|
if sample_weights:
|
||||||
sample_weights = [
|
sample_weights = [
|
||||||
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
||||||
@ -337,10 +336,9 @@ def test_on_batch(model,
|
|||||||
targets = training_utils.cast_if_floating_dtype(targets)
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
else:
|
else:
|
||||||
inputs = training_utils.cast_if_floating_to_model_input_dtypes(
|
inputs = training_utils.cast_if_floating_to_model_input_dtypes(
|
||||||
[ops.convert_to_tensor(val) for val in inputs], model)
|
inputs, model)
|
||||||
if targets:
|
if targets:
|
||||||
targets = training_utils.cast_if_floating_dtype(
|
targets = training_utils.cast_if_floating_dtype(targets)
|
||||||
[ops.convert_to_tensor(val) for val in targets])
|
|
||||||
if sample_weights:
|
if sample_weights:
|
||||||
sample_weights = [
|
sample_weights = [
|
||||||
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
training_utils.cast_if_floating_dtype(ops.convert_to_tensor(val))
|
||||||
|
@ -1191,7 +1191,8 @@ def check_steps_argument(input_data, steps, steps_name):
|
|||||||
|
|
||||||
|
|
||||||
def cast_single_tensor(x, dtype=None):
|
def cast_single_tensor(x, dtype=None):
|
||||||
x = ops.convert_to_tensor(x)
|
if isinstance(x, np.ndarray):
|
||||||
|
x = ops.convert_to_tensor(x)
|
||||||
dtype = dtype or K.floatx()
|
dtype = dtype or K.floatx()
|
||||||
if x.dtype.is_floating:
|
if x.dtype.is_floating:
|
||||||
return math_ops.cast(x, dtype=dtype)
|
return math_ops.cast(x, dtype=dtype)
|
||||||
|
@ -29,6 +29,7 @@ import functools
|
|||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.framework.ops import composite_tensor
|
||||||
from tensorflow.python.keras import backend
|
from tensorflow.python.keras import backend
|
||||||
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
|
from tensorflow.python.keras.distribute import distributed_training_utils as dist_utils
|
||||||
from tensorflow.python.keras.engine import training_eager
|
from tensorflow.python.keras.engine import training_eager
|
||||||
@ -125,7 +126,8 @@ def _get_input_from_iterator(iterator):
|
|||||||
"""Get elements from the iterator and verify the input shape and type."""
|
"""Get elements from the iterator and verify the input shape and type."""
|
||||||
next_element = next(iterator)
|
next_element = next(iterator)
|
||||||
|
|
||||||
if tensor_util.is_tensor(next_element) or isinstance(next_element, dict):
|
if (tensor_util.is_tensor(next_element) or
|
||||||
|
isinstance(next_element, (dict, composite_tensor.CompositeTensor))):
|
||||||
next_element = [next_element]
|
next_element = [next_element]
|
||||||
if len(next_element) == 1:
|
if len(next_element) == 1:
|
||||||
x, = next_element
|
x, = next_element
|
||||||
|
@ -26,6 +26,7 @@ import scipy.sparse
|
|||||||
from tensorflow.python import keras
|
from tensorflow.python import keras
|
||||||
|
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
@ -152,6 +153,17 @@ def get_model_from_layers_with_input(layers,
|
|||||||
raise ValueError("Unknown model type {}".format(model_type))
|
raise ValueError("Unknown model type {}".format(model_type))
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_mode_kwargs():
|
||||||
|
run_eagerly = testing_utils.should_run_eagerly()
|
||||||
|
# Certain things weren't supported correctly in the old path, therefore
|
||||||
|
# with these changes, some tests now only pass in the single code path in V2.
|
||||||
|
if run_eagerly or context.executing_eagerly():
|
||||||
|
run_distributed = True
|
||||||
|
else:
|
||||||
|
run_distributed = testing_utils.should_run_distributed()
|
||||||
|
return {"run_eagerly": run_eagerly, "run_distributed": run_distributed}
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@keras_parameterized.run_all_keras_modes
|
@keras_parameterized.run_all_keras_modes
|
||||||
class CompositeTensorInternalTest(keras_parameterized.TestCase):
|
class CompositeTensorInternalTest(keras_parameterized.TestCase):
|
||||||
@ -194,11 +206,7 @@ class CompositeTensorInternalTest(keras_parameterized.TestCase):
|
|||||||
input_data = np.random.rand(1024, 1)
|
input_data = np.random.rand(1024, 1)
|
||||||
expected_data = np.concatenate((input_data * 3, input_data * .5), axis=-1)
|
expected_data = np.concatenate((input_data * 3, input_data * .5), axis=-1)
|
||||||
|
|
||||||
model.compile(
|
model.compile(loss="mse", optimizer="adam", **get_test_mode_kwargs())
|
||||||
loss="mse",
|
|
||||||
optimizer="adam",
|
|
||||||
run_eagerly=testing_utils.should_run_eagerly(),
|
|
||||||
run_distributed=testing_utils.should_run_distributed())
|
|
||||||
history = model.fit(input_data, expected_data, epochs=10, verbose=0)
|
history = model.fit(input_data, expected_data, epochs=10, verbose=0)
|
||||||
|
|
||||||
# If the model trained, the loss stored at history[0] should be different
|
# If the model trained, the loss stored at history[0] should be different
|
||||||
@ -284,26 +292,28 @@ def get_input_name(use_dict):
|
|||||||
return "test_input_name"
|
return "test_input_name"
|
||||||
|
|
||||||
|
|
||||||
def get_steps():
|
def get_kwargs(use_dataset, action="predict"):
|
||||||
# Determine the steps arg (if appropriate)
|
if use_dataset or not context.executing_eagerly():
|
||||||
if not testing_utils.should_run_eagerly():
|
if action == "fit":
|
||||||
# CompositeTensors in graph mode are symbolic and so require a steps arg.
|
return {"steps_per_epoch": 1}
|
||||||
return 1
|
return {"steps": 1}
|
||||||
else:
|
else:
|
||||||
return None
|
return {"batch_size": 2}
|
||||||
|
|
||||||
|
|
||||||
def prepare_inputs(data, use_dict, use_dataset, action, input_name):
|
def prepare_inputs(data, use_dict, use_dataset, action, input_name):
|
||||||
input_data, expected_output = data
|
input_data, expected_output = data
|
||||||
|
batch_size = input_data.shape[0]
|
||||||
# Prepare the input data.
|
# Prepare the input data.
|
||||||
if use_dict:
|
if use_dict:
|
||||||
input_data = {input_name: input_data}
|
input_data = {input_name: input_data}
|
||||||
if use_dataset:
|
if use_dataset:
|
||||||
if action == "predict":
|
if action == "predict":
|
||||||
input_data = dataset_ops.Dataset.from_tensors(input_data)
|
input_data = dataset_ops.DatasetV2.from_tensor_slices(input_data).batch(
|
||||||
|
batch_size)
|
||||||
else:
|
else:
|
||||||
input_data = dataset_ops.Dataset.from_tensors(
|
input_data = dataset_ops.DatasetV2.from_tensor_slices(
|
||||||
(input_data, expected_output))
|
(input_data, expected_output)).batch(batch_size)
|
||||||
expected_output = None
|
expected_output = None
|
||||||
return (input_data, expected_output)
|
return (input_data, expected_output)
|
||||||
|
|
||||||
@ -332,8 +342,12 @@ class SparseTensorInputTest(keras_parameterized.TestCase):
|
|||||||
shape=(1, None), sparse=True, name=input_name, dtype=dtypes.int32)
|
shape=(1, None), sparse=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
steps = get_steps()
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
**get_test_mode_kwargs())
|
||||||
|
kwargs = get_kwargs(use_dataset, action)
|
||||||
|
|
||||||
# Prepare the input data
|
# Prepare the input data
|
||||||
for data_element in data:
|
for data_element in data:
|
||||||
@ -342,15 +356,14 @@ class SparseTensorInputTest(keras_parameterized.TestCase):
|
|||||||
input_name)
|
input_name)
|
||||||
# Perform the action.
|
# Perform the action.
|
||||||
if action == "predict":
|
if action == "predict":
|
||||||
result = model.predict(input_data, steps=steps)
|
result = model.predict(input_data, **kwargs)
|
||||||
self.assertAllEqual(expected_output, result)
|
self.assertAllEqual(expected_output, result)
|
||||||
if action == "evaluate":
|
if action == "evaluate":
|
||||||
result = model.evaluate(input_data, expected_output, steps=steps)
|
result = model.evaluate(input_data, expected_output, **kwargs)
|
||||||
self.assertAllEqual(1.0, result[-1])
|
self.assertAllEqual(1.0, result[-1])
|
||||||
if action == "fit":
|
if action == "fit":
|
||||||
# TODO(momernick): What's the best way of validating that fit happened?
|
# TODO(momernick): What's the best way of validating that fit happened?
|
||||||
_ = model.fit(
|
_ = model.fit(input_data, expected_output, shuffle=False, **kwargs)
|
||||||
input_data, expected_output, shuffle=False, steps_per_epoch=steps)
|
|
||||||
|
|
||||||
|
|
||||||
@keras_parameterized.run_with_all_model_types
|
@keras_parameterized.run_with_all_model_types
|
||||||
@ -385,7 +398,11 @@ class ScipySparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int64)
|
model_input = input_layer.Input(shape=(3,), sparse=True, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
|
|
||||||
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
input_data = scipy.sparse.coo_matrix(([1, 2, 3], ([0, 1, 1], [0, 0, 1])),
|
||||||
shape=[2, 3])
|
shape=[2, 3])
|
||||||
@ -443,7 +460,11 @@ class ScipySparseTensorInputTest(keras_parameterized.TestCase,
|
|||||||
shape=(3,), sparse=True, name=input_name, dtype=dtypes.int64)
|
shape=(3,), sparse=True, name=input_name, dtype=dtypes.int64)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
run_distributed=testing_utils.should_run_distributed())
|
||||||
|
|
||||||
input_data = {
|
input_data = {
|
||||||
input_name:
|
input_name:
|
||||||
@ -484,7 +505,11 @@ class RaggedTensorInputTest(keras_parameterized.TestCase,
|
|||||||
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
|
shape=(None, None), ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
**get_test_mode_kwargs())
|
||||||
|
|
||||||
# Prepare the input data
|
# Prepare the input data
|
||||||
for data_element in data:
|
for data_element in data:
|
||||||
@ -524,7 +549,11 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
|||||||
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
**get_test_mode_kwargs())
|
||||||
|
|
||||||
for data_element in data:
|
for data_element in data:
|
||||||
input_data, expected_output = prepare_inputs(
|
input_data, expected_output = prepare_inputs(
|
||||||
@ -549,11 +578,12 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
|||||||
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
# The input is a symbolic tensor in non-Eager modes, so 'steps' is required
|
loss="mse",
|
||||||
# for that case only.
|
metrics=["accuracy"],
|
||||||
steps = get_steps()
|
**get_test_mode_kwargs())
|
||||||
|
kwargs = get_kwargs(use_dataset)
|
||||||
|
|
||||||
for data_element in data:
|
for data_element in data:
|
||||||
input_data, expected_output = prepare_inputs(
|
input_data, expected_output = prepare_inputs(
|
||||||
@ -562,7 +592,7 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
|||||||
use_dataset,
|
use_dataset,
|
||||||
action="predict",
|
action="predict",
|
||||||
input_name=input_name)
|
input_name=input_name)
|
||||||
result = model.predict(input_data, steps=steps)
|
result = model.predict(input_data, **kwargs)
|
||||||
self.assertAllEqual(expected_output, result)
|
self.assertAllEqual(expected_output, result)
|
||||||
|
|
||||||
def test_ragged_tensor_input_with_wrong_ragged_rank_fails(
|
def test_ragged_tensor_input_with_wrong_ragged_rank_fails(
|
||||||
@ -577,7 +607,11 @@ class RaggedTensorInputValidationTest(keras_parameterized.TestCase,
|
|||||||
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
shape=input_shape, ragged=True, name=input_name, dtype=dtypes.int32)
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
model = get_model_from_layers_with_input(layers, model_input=model_input)
|
||||||
model.compile(optimizer="sgd", loss="mse", metrics=["accuracy"])
|
model.compile(
|
||||||
|
optimizer="sgd",
|
||||||
|
loss="mse",
|
||||||
|
metrics=["accuracy"],
|
||||||
|
**get_test_mode_kwargs())
|
||||||
|
|
||||||
# Define some input data with the wrong ragged rank
|
# Define some input data with the wrong ragged rank
|
||||||
for data_element in data:
|
for data_element in data:
|
||||||
@ -618,15 +652,9 @@ class SparseTensorInputValidationTest(keras_parameterized.TestCase):
|
|||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
input_data = sparse_tensor.SparseTensor([[0, 0, 0], [1, 0, 0], [1, 0, 1]],
|
||||||
[1, 2, 3], [2, 1, 3])
|
[1, 2, 3], [2, 1, 3])
|
||||||
if not testing_utils.should_run_eagerly():
|
kwargs = get_kwargs(use_dataset=False)
|
||||||
# This ragged tensor is actually a standard tensor (as it has no ragged
|
|
||||||
# dimensions). Because of this, graph mode models will expect a steps
|
|
||||||
# arg to be passed (as SparseTensors in graph mode are symbolic).
|
|
||||||
steps = 1
|
|
||||||
else:
|
|
||||||
steps = None
|
|
||||||
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
with self.assertRaisesRegex(ValueError, ".*got array with shape.*"):
|
||||||
_ = model.predict(input_data, steps=steps)
|
_ = model.predict(input_data, **kwargs)
|
||||||
|
|
||||||
def test_ragged_tensor_input_with_wrong_value_shape(self):
|
def test_ragged_tensor_input_with_wrong_value_shape(self):
|
||||||
# Create a model that accepts a ragged input and converts it to dense.
|
# Create a model that accepts a ragged input and converts it to dense.
|
||||||
@ -652,14 +680,14 @@ class UndefinedCompositeTensorInputsTest(keras_parameterized.TestCase):
|
|||||||
# back to a dense tensor.
|
# back to a dense tensor.
|
||||||
layers = [ToDense(default_value=-1)]
|
layers = [ToDense(default_value=-1)]
|
||||||
model = testing_utils.get_model_from_layers(layers)
|
model = testing_utils.get_model_from_layers(layers)
|
||||||
steps = get_steps()
|
|
||||||
|
|
||||||
# Define some input data.
|
# Define some input data.
|
||||||
input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3],
|
input_data = sparse_tensor.SparseTensor([[0, 0], [1, 0], [1, 1]], [1, 2, 3],
|
||||||
[2, 3])
|
[2, 3])
|
||||||
|
kwargs = get_kwargs(False)
|
||||||
with self.assertRaisesRegex(
|
with self.assertRaisesRegex(
|
||||||
ValueError, ".*All SparseTensor and RaggedTensor inputs .*"):
|
ValueError, ".*All SparseTensor and RaggedTensor inputs .*"):
|
||||||
_ = model.predict(input_data, steps=steps)
|
_ = model.predict(input_data, **kwargs)
|
||||||
|
|
||||||
def test_subclass_implicit_sparse_scipy_inputs_fails(self):
|
def test_subclass_implicit_sparse_scipy_inputs_fails(self):
|
||||||
# Create a model that accepts a sparse input and converts the sparse tensor
|
# Create a model that accepts a sparse input and converts the sparse tensor
|
||||||
|
Loading…
Reference in New Issue
Block a user