在examples目录下的许多例子都涉及到数据集下载,查看源码你会发现它们最终都是通过keras.utils.data_utils.get_file函数下载的,这是一个普适的数据集下载工具函数,所以有必要了解其功能,以便更好的应用。get_file函数签名如下:
get_file(fname, origin, untar=False, md5_hash=None, file_hash=None, cache_subdir='datasets', hash_algorithm='auto', extract=False, archive_format='auto', cache_dir=None)
fname指的是缓存到本地的文件名,origin其实就是数据集文件的下载地址,即URL,其它参数基本都是不言自明的。函数大致实现如下:
(1)依据cache_dir得到数据文件就存放的文件夹,即:datadir。正常情况下,应该是
if cache_dir is None: cache_dir = os.path.join(os.path.expanduser('~'), '.keras') datadir_base = os.path.expanduser(cache_dir) if not os.access(datadir_base, os.W_OK): datadir_base = os.path.join('/tmp', '.keras') datadir = os.path.join(datadir_base, cache_subdir) if not os.path.exists(datadir): os.makedirs(datadir) (2)由datadir和fname得到下载到本地的文件名:fpath if untar: untar_fpath = os.path.join(datadir, fname) fpath = untar_fpath + '.tar.gz'else: fpath = os.path.join(datadir, fname) (3)如果文件存在,则不用下载, download = False if os.path.exists(fpath): ...... else: download = True (4)否则,下载文件 if download: print('Downloading data from', origin) try: try: urlretrieve(origin, fpath, dl_progress) except HTTPError as e: raise Exception(error_msg.format(origin, e.code, e.msg)) except URLError as e: raise Exception(error_msg.format(origin, e.errno, e.reason)) except (Exception, KeyboardInterrupt): if os.path.exists(fpath): os.remove(fpath) raise (5)确定是否解压,如需要用extract(untar已过时),最后返回已下载的文件路径 if untar: if not os.path.exists(untar_fpath): _extract_archive(fpath, datadir, archive_format='tar') return untar_fpath if extract: _extract_archive(fpath, datadir, archive_format) return fpath
没有评论:
发表评论