Add a type-spec conversion for the HDF5Matrix class.

PiperOrigin-RevId: 254091284
This commit is contained in:
A. Unique TensorFlower 2019-06-19 16:09:06 -07:00 committed by TensorFlower Gardener
parent cf9bdb260f
commit 3877e8ee47

View File

@ -22,9 +22,10 @@ import collections
import numpy as np
import six
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.util.tf_export import keras_export
try:
import h5py
except ImportError:
@ -149,6 +150,25 @@ class HDF5Matrix(object):
"""
return np.prod(self.shape)
@staticmethod
def _to_type_spec(value):
"""Gets the Tensorflow TypeSpec corresponding to the passed dataset.
Args:
value: A HDF5Matrix object.
Returns:
A tf.TensorSpec.
"""
if not isinstance(value, HDF5Matrix):
raise TypeError('Expected value to be a HDF5Matrix, but saw: {}'.format(
type(value)))
return tensor_spec.TensorSpec(shape=value.shape, dtype=value.dtype)
type_spec.register_type_spec_from_value_converter(HDF5Matrix,
HDF5Matrix._to_type_spec) # pylint: disable=protected-access
def ask_to_proceed_with_overwrite(filepath):
"""Produces a prompt asking about overwriting a file.