mirror of
https://gitee.com/fastnlp/fastNLP.git
synced 2024-12-04 21:28:01 +08:00
DataSet apply的时候可以传入use_tqdm和tqdm_desc
This commit is contained in:
parent
4886dbff8b
commit
04ad8e604e
@ -371,6 +371,10 @@ from .field import SetInputOrTargetException
|
||||
from .instance import Instance
|
||||
from .utils import pretty_table_printer
|
||||
from .collate_fn import Collater
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except:
|
||||
from .utils import _pseudo_tqdm as tqdm
|
||||
|
||||
|
||||
class ApplyResultException(Exception):
|
||||
@ -860,6 +864,11 @@ class DataSet(object):
|
||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
|
||||
|
||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
|
||||
"""
|
||||
assert len(self) != 0, "Null DataSet cannot use apply_field()."
|
||||
@ -887,6 +896,10 @@ class DataSet(object):
|
||||
|
||||
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return Dict[str:Field]: 返回一个字典
|
||||
"""
|
||||
assert len(self) != 0, "Null DataSet cannot use apply_field()."
|
||||
@ -949,6 +962,10 @@ class DataSet(object):
|
||||
|
||||
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return Dict[str:Field]: 返回一个字典
|
||||
"""
|
||||
# 返回 dict , 检查是否一直相同
|
||||
@ -957,7 +974,9 @@ class DataSet(object):
|
||||
idx = -1
|
||||
try:
|
||||
results = {}
|
||||
for idx, ins in enumerate(self._inner_iter()):
|
||||
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True,
|
||||
desc=kwargs.get('tqdm_desc', ''),
|
||||
leave=False, disable=not kwargs.get('use_tqdm', False)):
|
||||
if "_apply_field" in kwargs:
|
||||
res = func(ins[kwargs["_apply_field"]])
|
||||
else:
|
||||
@ -1002,6 +1021,10 @@ class DataSet(object):
|
||||
|
||||
3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否使用tqdm显示预处理进度
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
|
||||
"""
|
||||
assert callable(func), "The func you provide is not callable."
|
||||
@ -1009,7 +1032,9 @@ class DataSet(object):
|
||||
idx = -1
|
||||
try:
|
||||
results = []
|
||||
for idx, ins in enumerate(self._inner_iter()):
|
||||
for idx, ins in tqdm(enumerate(self._inner_iter()), total=len(self), dynamic_ncols=True, leave=False,
|
||||
desc=kwargs.get('tqdm_desc', ''),
|
||||
disable=not kwargs.get('use_tqdm', False)):
|
||||
if "_apply_field" in kwargs:
|
||||
results.append(func(ins[kwargs["_apply_field"]]))
|
||||
else:
|
||||
|
@ -321,8 +321,15 @@ class DataBundle:
|
||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
|
||||
|
||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否显示tqdm进度条
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
"""
|
||||
tqdm_desc = kwargs.get('tqdm_desc', '')
|
||||
for name, dataset in self.datasets.items():
|
||||
if tqdm_desc != '':
|
||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
|
||||
if dataset.has_field(field_name=field_name):
|
||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs)
|
||||
elif not ignore_miss_dataset:
|
||||
@ -350,10 +357,17 @@ class DataBundle:
|
||||
|
||||
3. ignore_type: bool, 如果为True则将被修改的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否显示tqdm进度条
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
|
||||
"""
|
||||
res = {}
|
||||
tqdm_desc = kwargs.get('tqdm_desc', '')
|
||||
for name, dataset in self.datasets.items():
|
||||
if tqdm_desc != '':
|
||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
|
||||
if dataset.has_field(field_name=field_name):
|
||||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, modify_fields=modify_fields, **kwargs)
|
||||
elif not ignore_miss_dataset:
|
||||
@ -376,8 +390,16 @@ class DataBundle:
|
||||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
|
||||
|
||||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否显示tqdm进度条
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
"""
|
||||
tqdm_desc = kwargs.get('tqdm_desc', '')
|
||||
for name, dataset in self.datasets.items():
|
||||
if tqdm_desc != '':
|
||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
|
||||
dataset.apply(func, new_field_name=new_field_name, **kwargs)
|
||||
return self
|
||||
|
||||
@ -399,10 +421,17 @@ class DataBundle:
|
||||
|
||||
3. ignore_type: bool, 如果为True则将被修改的的field的ignore_type设置为true, 忽略其类型
|
||||
|
||||
4. use_tqdm: bool, 是否显示tqdm进度条
|
||||
|
||||
5. tqdm_desc: str, 当use_tqdm为True时,可以显示当前tqdm正在处理的名称
|
||||
|
||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
|
||||
"""
|
||||
res = {}
|
||||
tqdm_desc = kwargs.get('tqdm_desc', '')
|
||||
for name, dataset in self.datasets.items():
|
||||
if tqdm_desc!='':
|
||||
kwargs['tqdm_desc'] = tqdm_desc + f' for `{name}`'
|
||||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, **kwargs)
|
||||
return res
|
||||
|
||||
|
@ -136,6 +136,14 @@ class TestDataSetMethods(unittest.TestCase):
|
||||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k", ignore_type=True)
|
||||
# expect no exception raised
|
||||
|
||||
def test_apply_tqdm(self):
|
||||
import time
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
def do_nothing(ins):
|
||||
time.sleep(0.01)
|
||||
ds.apply(do_nothing, use_tqdm=True)
|
||||
ds.apply_field(do_nothing, field_name='x', use_tqdm=True)
|
||||
|
||||
def test_apply_cannot_modify_instance(self):
|
||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
|
||||
def modify_inplace(instance):
|
||||
|
Loading…
Reference in New Issue
Block a user