From 3877e8ee4777e19c2166788a885d5e5bbe8088f7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 19 Jun 2019 16:09:06 -0700 Subject: [PATCH] Add a type-spec conversion for the HDF5Matrix class. PiperOrigin-RevId: 254091284 --- tensorflow/python/keras/utils/io_utils.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/tensorflow/python/keras/utils/io_utils.py b/tensorflow/python/keras/utils/io_utils.py index 6e74c5e2277..0efef16a822 100644 --- a/tensorflow/python/keras/utils/io_utils.py +++ b/tensorflow/python/keras/utils/io_utils.py @@ -22,9 +22,10 @@ import collections import numpy as np import six +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec from tensorflow.python.util.tf_export import keras_export - try: import h5py except ImportError: @@ -149,6 +150,25 @@ class HDF5Matrix(object): """ return np.prod(self.shape) + @staticmethod + def _to_type_spec(value): + """Gets the Tensorflow TypeSpec corresponding to the passed dataset. + + Args: + value: A HDF5Matrix object. + + Returns: + A tf.TensorSpec. + """ + if not isinstance(value, HDF5Matrix): + raise TypeError('Expected value to be a HDF5Matrix, but saw: {}'.format( + type(value))) + return tensor_spec.TensorSpec(shape=value.shape, dtype=value.dtype) + + +type_spec.register_type_spec_from_value_converter(HDF5Matrix, + HDF5Matrix._to_type_spec) # pylint: disable=protected-access + def ask_to_proceed_with_overwrite(filepath): """Produces a prompt asking about overwriting a file.