Add a type-spec conversion for the HDF5Matrix class.
PiperOrigin-RevId: 254091284
This commit is contained in:
parent
cf9bdb260f
commit
3877e8ee47
@ -22,9 +22,10 @@ import collections
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
try:
|
||||
import h5py
|
||||
except ImportError:
|
||||
@ -149,6 +150,25 @@ class HDF5Matrix(object):
|
||||
"""
|
||||
return np.prod(self.shape)
|
||||
|
||||
@staticmethod
|
||||
def _to_type_spec(value):
|
||||
"""Gets the Tensorflow TypeSpec corresponding to the passed dataset.
|
||||
|
||||
Args:
|
||||
value: A HDF5Matrix object.
|
||||
|
||||
Returns:
|
||||
A tf.TensorSpec.
|
||||
"""
|
||||
if not isinstance(value, HDF5Matrix):
|
||||
raise TypeError('Expected value to be a HDF5Matrix, but saw: {}'.format(
|
||||
type(value)))
|
||||
return tensor_spec.TensorSpec(shape=value.shape, dtype=value.dtype)
|
||||
|
||||
|
||||
type_spec.register_type_spec_from_value_converter(HDF5Matrix,
|
||||
HDF5Matrix._to_type_spec) # pylint: disable=protected-access
|
||||
|
||||
|
||||
def ask_to_proceed_with_overwrite(filepath):
|
||||
"""Produces a prompt asking about overwriting a file.
|
||||
|
Loading…
Reference in New Issue
Block a user