-
Notifications
You must be signed in to change notification settings - Fork 3
/
helper.py
68 lines (54 loc) · 2.14 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import hashlib
from urllib.request import urlretrieve
import shutil
from tqdm import tqdm
def download_extract(data_path):
    """
    Download and extract the CelebA database

    Returns immediately when the extraction folder already exists. Otherwise
    the archive is downloaded into data_path (unless a copy is already there),
    its MD5 checksum is verified, it is handed to _unzip and the archive file
    is deleted afterwards.

    :param data_path: Directory where the archive is saved and extracted
    :raises AssertionError: If the archive's MD5 checksum does not match
    """
    database_name = 'CelebA'
    url = 'https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip'
    hash_code = '00d2c5bc6d35e252742224ab0c1e8fcb'
    extract_path = os.path.join(data_path, 'img_align_celeba')
    save_path = os.path.join(data_path, 'celeba.zip')

    if os.path.exists(extract_path):
        print('Found {} Data'.format(database_name))
        return

    # exist_ok avoids the check-then-create race of a separate exists() test
    os.makedirs(data_path, exist_ok=True)

    if not os.path.exists(save_path):
        with DLProgress(unit='B', unit_scale=True, miniters=1, desc='Downloading {}'.format(database_name)) as pbar:
            urlretrieve(
                url,
                save_path,
                pbar.hook)

    # Hash in chunks inside a context manager: the archive is large, so do not
    # load it into memory at once, and always close the file handle.
    md5 = hashlib.md5()
    with open(save_path, 'rb') as archive:
        for chunk in iter(lambda: archive.read(1024 * 1024), b''):
            md5.update(chunk)

    # Raised explicitly (not via an `assert` statement) so that the integrity
    # check still runs when Python is started with -O.
    if md5.hexdigest() != hash_code:
        raise AssertionError(
            '{} file is corrupted. Remove the file and try again.'.format(save_path))

    os.makedirs(extract_path)
    try:
        _unzip(save_path, extract_path, database_name, data_path)
    except Exception:
        shutil.rmtree(extract_path)  # Remove extraction folder if there is an error
        raise

    # Remove compressed data
    os.remove(save_path)
class DLProgress(tqdm):
    """
    Progress bar that tracks a download through a urlretrieve-style report hook
    """
    # Block count seen on the previous hook call; used to compute the increment
    last_block = 0

    def hook(self, block_num=1, block_size=1, total_size=None):
        """
        Advance the bar by the number of bytes received since the previous call.
        Passed as the report hook to urlretrieve, which supplies the arguments.
        :param block_num: A count of blocks transferred so far
        :param block_size: Block size in bytes
        :param total_size: The total size of the file. This may be -1 on older FTP servers which do not return
                            a file size in response to a retrieval request.
        """
        self.total = total_size
        # block_num is cumulative, so convert it to a delta before updating
        blocks_since_last_call = block_num - self.last_block
        self.update(blocks_since_last_call * block_size)
        self.last_block = block_num