support passing in a source url to the mnist read_data_sets function, to make it easier to use 'fashion mnist' etc. (#12983)
This commit is contained in:
parent
79517578de
commit
4af9be964e
@ -30,7 +30,7 @@ from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.platform import gfile
|
||||
|
||||
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
|
||||
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
|
||||
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
|
||||
|
||||
|
||||
def _read32(bytestream):
|
||||
@ -215,7 +215,8 @@ def read_data_sets(train_dir,
|
||||
dtype=dtypes.float32,
|
||||
reshape=True,
|
||||
validation_size=5000,
|
||||
seed=None):
|
||||
seed=None,
|
||||
source_url=DEFAULT_SOURCE_URL):
|
||||
if fake_data:
|
||||
|
||||
def fake():
|
||||
@ -227,28 +228,31 @@ def read_data_sets(train_dir,
|
||||
test = fake()
|
||||
return base.Datasets(train=train, validation=validation, test=test)
|
||||
|
||||
if not source_url: # empty string check
|
||||
source_url = DEFAULT_SOURCE_URL
|
||||
|
||||
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
|
||||
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
|
||||
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
|
||||
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
|
||||
|
||||
local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
|
||||
SOURCE_URL + TRAIN_IMAGES)
|
||||
source_url + TRAIN_IMAGES)
|
||||
with gfile.Open(local_file, 'rb') as f:
|
||||
train_images = extract_images(f)
|
||||
|
||||
local_file = base.maybe_download(TRAIN_LABELS, train_dir,
|
||||
SOURCE_URL + TRAIN_LABELS)
|
||||
source_url + TRAIN_LABELS)
|
||||
with gfile.Open(local_file, 'rb') as f:
|
||||
train_labels = extract_labels(f, one_hot=one_hot)
|
||||
|
||||
local_file = base.maybe_download(TEST_IMAGES, train_dir,
|
||||
SOURCE_URL + TEST_IMAGES)
|
||||
source_url + TEST_IMAGES)
|
||||
with gfile.Open(local_file, 'rb') as f:
|
||||
test_images = extract_images(f)
|
||||
|
||||
local_file = base.maybe_download(TEST_LABELS, train_dir,
|
||||
SOURCE_URL + TEST_LABELS)
|
||||
source_url + TEST_LABELS)
|
||||
with gfile.Open(local_file, 'rb') as f:
|
||||
test_labels = extract_labels(f, one_hot=one_hot)
|
||||
|
||||
@ -262,13 +266,13 @@ def read_data_sets(train_dir,
|
||||
train_images = train_images[validation_size:]
|
||||
train_labels = train_labels[validation_size:]
|
||||
|
||||
|
||||
|
||||
options = dict(dtype=dtype, reshape=reshape, seed=seed)
|
||||
|
||||
|
||||
train = DataSet(train_images, train_labels, **options)
|
||||
validation = DataSet(validation_images, validation_labels, **options)
|
||||
test = DataSet(test_images, test_labels, **options)
|
||||
|
||||
|
||||
return base.Datasets(train=train, validation=validation, test=test)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user