From 4af9be964eff70b9f27f605e0f5b2cb04a5d03cc Mon Sep 17 00:00:00 2001 From: Amy Date: Mon, 11 Sep 2017 18:01:47 -0700 Subject: [PATCH] support passing in a source url to the mnist read_data_sets function, to make it easier to use 'fashion mnist' etc. (#12983) --- .../learn/python/learn/datasets/mnist.py | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index a90b9264f85..1f3295747e1 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -30,7 +30,7 @@ from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile # CVDF mirror of http://yann.lecun.com/exdb/mnist/ -SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' +DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' def _read32(bytestream): @@ -215,7 +215,8 @@ def read_data_sets(train_dir, dtype=dtypes.float32, reshape=True, validation_size=5000, - seed=None): + seed=None, + source_url=DEFAULT_SOURCE_URL): if fake_data: def fake(): @@ -227,28 +228,31 @@ def read_data_sets(train_dir, test = fake() return base.Datasets(train=train, validation=validation, test=test) + if not source_url: # empty string check + source_url = DEFAULT_SOURCE_URL + TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TEST_IMAGES = 't10k-images-idx3-ubyte.gz' TEST_LABELS = 't10k-labels-idx1-ubyte.gz' local_file = base.maybe_download(TRAIN_IMAGES, train_dir, - SOURCE_URL + TRAIN_IMAGES) + source_url + TRAIN_IMAGES) with gfile.Open(local_file, 'rb') as f: train_images = extract_images(f) local_file = base.maybe_download(TRAIN_LABELS, train_dir, - SOURCE_URL + TRAIN_LABELS) + source_url + TRAIN_LABELS) with gfile.Open(local_file, 'rb') as f: train_labels = extract_labels(f, one_hot=one_hot) local_file = base.maybe_download(TEST_IMAGES, train_dir, - SOURCE_URL + TEST_IMAGES) + source_url + TEST_IMAGES) with gfile.Open(local_file, 'rb') as f: test_images = extract_images(f) local_file = base.maybe_download(TEST_LABELS, train_dir, - SOURCE_URL + TEST_LABELS) + source_url + TEST_LABELS) with gfile.Open(local_file, 'rb') as f: test_labels = extract_labels(f, one_hot=one_hot) @@ -262,13 +266,13 @@ def read_data_sets(train_dir, train_images = train_images[validation_size:] train_labels = train_labels[validation_size:] - + options = dict(dtype=dtype, reshape=reshape, seed=seed) - + train = DataSet(train_images, train_labels, **options) validation = DataSet(validation_images, validation_labels, **options) test = DataSet(test_images, test_labels, **options) - + return base.Datasets(train=train, validation=validation, test=test)