Merge pull request #11 from xuyige/master

update config loader
This commit is contained in:
Coet 2018-07-14 16:11:02 +08:00 committed by GitHub
commit 4a82b5d446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 93 additions and 7 deletions

54
fastNLP/loader/config Normal file
View File

@ -0,0 +1,54 @@
[General]
revision = "first"
datapath = "./data/smallset/imdb/"
embed_path = "./data/smallset/imdb/embedding.txt"
optimizer = "adam"
attn_mode = "rout"
seq_encoder = "bilstm"
out_caps_num = 5
rout_iter = 3
max_snt_num = 40
max_wd_num = 40
max_epochs = 50
pre_trained = true
batch_sz = 32
batch_sz_min = 32
bucket_sz = 5000
partial_update_until_epoch = 2
embed_size = 300
hidden_size = 200
dense_hidden = [300, 10]
lr = 0.0002
decay_steps = 1000
decay_rate = 0.9
dropout = 0.2
early_stopping = 7
reg = 1e-06
[My]
datapath = "./data/smallset/imdb/"
embed_path = "./data/smallset/imdb/embedding.txt"
optimizer = "adam"
attn_mode = "rout"
seq_encoder = "bilstm"
out_caps_num = 5
rout_iter = 3
max_snt_num = 40
max_wd_num = 40
max_epochs = 50
pre_trained = true
batch_sz = 32
batch_sz_min = 32
bucket_sz = 5000
partial_update_until_epoch = 2
embed_size = 300
hidden_size = 200
dense_hidden = [300, 10]
lr = 0.0002
decay_steps = 1000
decay_rate = 0.9
dropout = 0.2
early_stopping = 70
reg = 1e-05
test = 5
new_attr = 40

View File

@ -25,17 +25,49 @@ class ConfigLoader(BaseLoader):
cfg = configparser.ConfigParser()
cfg.read(file_path)
for s in sections:
attr_list = [i for i in type(sections[s]).__dict__.keys() if
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in attr_list:
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
print(s, attr, val, type(val))
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
#print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
# attribute attr in section s did not been set, default val will be used
print("cannot load attribute %s in section %s"
% (attr, s))
pass
if __name__ == "__name__":
config = ConfigLoader('configLoader', 'there is no data')
class ConfigSection(object):
def __init__(self):
pass
section = {'General': ConfigSection(), 'My': ConfigSection(), 'A': ConfigSection()}
"""
General and My can be found in config file, so the attr and
value will be updated
A cannot be found in config file, so nothing will be done
"""
config.load_config("config", section)
for s in section:
print(s)
for attr in section[s].__dict__.keys():
print(s, attr, getattr(section[s], attr))