!253 Online inference interface supports multi-batch input

Merge pull request !253 from shishaoyu/master
This commit is contained in:
i-robot 2023-11-13 08:04:26 +00:00 committed by Gitee
commit 63ee9d5ca3
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
7 changed files with 162 additions and 151 deletions

OWNERS
View File

@ -10,3 +10,4 @@ reviewers:
- zhangshengdong
- kingsleyandher
- guo-xinjie-1
- matrixssy

View File

@ -278,10 +278,14 @@ class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):
Parameters:
----------
input_ids(str | torch.Tensor):
input_ids(str | list | LongTensor):
The text entered by the user, e.g. 'Hello!'
Or
The text entered by the user, encoded by the tokenizer, e.g. [0, 13, 5, ...]
Or
A list containing multiple texts or token sequences,
e.g. [['Hello!'], ["How are you?"]] | [[0, 13, 5, ...], [0, 21, ...]].
Note that multiple texts or token sequences are forbidden in beam-search mode.
do_sample (`bool`, *optional*, defaults to `False`):
Whether to use sampling; use greedy decoding otherwise.
top_k (`int`, *optional*, defaults to 0):
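
For reference, a minimal call sketch of the batched interface this docstring describes, mirroring the updated task_do_sample test at the end of this diff; the model object, its construction, and the concrete keyword values here are assumptions, not part of the commit:

# Hypothetical sketch: `model` is assumed to be any module that mixes in
# MegatronModuleForCausalLM; how it is built is elided here.
def run_batched_generate(model):
    prompts = ["Hello!", "How are you?"]   # a list of texts is now accepted
    return model.generate(
        prompts,
        do_sample=True,   # sampling; greedy decoding when False
        top_k=50,
        top_p=0.95,
    )

With more than one prompt the post-processing below returns a list with one entry per prompt; with a single prompt it collapses to that entry, and beam search still requires a single prompt.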
@ -420,37 +424,11 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
return engine
@staticmethod
def _broadcast_tokens(context_tokens, context_length, master_rank):
if dist.get_world_size() > 1:
if dist.get_rank() == master_rank:
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
dist.broadcast(context_tokens_tensor, master_rank)
else:
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device(get_accelerator().device_name()))
dist.broadcast(context_tokens_tensor, master_rank)
else:
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
return context_tokens_tensor
@staticmethod
def _check_output(output, stream):
if not stream:
full_output = None
for tmp in output:
full_output = tmp
return full_output
else:
return output
@staticmethod
def _ids_check(ids, tokenizer):
checked_ids = []
for per_ids in ids:
if torch.max(per_ids) >= len(tokenizer):
if per_ids.shape != torch.Size([]) and torch.max(per_ids) >= len(tokenizer):
warning_info = "The output ids exceed the tokenizer length; "\
"the clamp operation is enforced, please check!"
logging.warning(warning_info)
@ -537,14 +515,9 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
# so you don't need to pass the prompt on
# each process.
# =======================================
context_length, context_tokens, master_rank = self._tokenize(input_ids)
# =======================================
# For parallel we need to send context tokens
# to other process
# =======================================
context_tokens_tensor = self._broadcast_tokens(context_tokens, context_length, master_rank).unsqueeze(0)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
context_tokens, master_rank = self._tokenize(input_ids)
args.master_rank = master_rank
args.micro_batch_size = len(context_tokens)
# =======================================
# Get the streaming tokens generator
@ -568,12 +541,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
# Post processions in order to get final
# output texts/tokens
# =======================================
output = self._post_processing(token_stream,
context_length,
self.include_input,
self.detokenize,
self.num_beams)
return self._check_output(output, self.stream)
return self._token_generator(token_stream)
def _init_tokenizer(self, args):
if self.tokenizer_new is None:
@ -596,57 +564,94 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
else:
raise ValueError("Your tokenizer doesn't include eos_token.")
if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
def _tokenize(self, input_ids):
context_tokens = [[]]
broadcast_rank = torch.zeros(dist.get_world_size(),
dtype=torch.int64,
device=torch.device(get_accelerator().device_name()))
if input_ids:
if input_ids is not None and len(input_ids) > 0:
if isinstance(input_ids, str):
context_tokens = self.tokenizer.encode(input_ids)
context_tokens = [self.tokenizer.encode(input_ids)]
elif torch.is_tensor(input_ids):
if len(input_ids.shape) == 1:
context_tokens = input_ids.unsqueeze(0).numpy().tolist()
elif len(input_ids.shape) == 2:
context_tokens = input_ids.numpy().tolist()
elif isinstance(input_ids, (tuple, list)):
if len(input_ids) and isinstance(input_ids[0], (tuple, list)):
context_tokens = input_ids
elif len(input_ids) and isinstance(input_ids[0], int):
context_tokens = [input_ids]
elif len(input_ids) and isinstance(input_ids[0], str):
context_tokens = [self.tokenizer.encode(val) for val in input_ids]
else:
context_tokens = input_ids
raise TypeError("Please check that input_ids is of a supported type.")
context_length = len(context_tokens)
counts = 1
broadcast_rank[dist.get_rank()] = 1
else:
context_tokens = [self.tokenizer.encode("EMPTY TEXT")]
context_length = 0
counts = 0
input_info = [counts, context_length]
input_info_tensor = get_accelerator().LongTensor(input_info)
dist.all_reduce(input_info_tensor)
dist.all_reduce(broadcast_rank)
counts = input_info_tensor[0].item()
if counts == 0:
raise ValueError("Please pass prompt on at least one process.")
context_length = input_info_tensor[1].item() // counts
master_rank = torch.nonzero(broadcast_rank)[0, 0]
return context_length, context_tokens, master_rank
def _post_processing(self, token_stream, context_length, include_input, detokenize, num_beams):
for output, _, log_probs in token_stream:
if not include_input:
output = [val[context_length:] for val in output]
return context_tokens, master_rank
if detokenize:
try:
output_checked = self._ids_check(output, self.tokenizer)
output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
except Exception as e:
error_info = "Meet errors when trying to decode the tokens. "\
"Please handle it by yourself."
logging.error(error_info)
logging.error(e)
def _post_processing(self, output, context_lengths, log_probs):
if not self.include_input:
output = [val[context_lengths[i]:] for i, val in enumerate(output)]
output = output[0] if len(output) == 1 else output
# When batch size > 1, you need to truncate the tokens after eos_token_id
self._truncate_in_multi_batch(output)
if not self.return_output_log_probs:
yield output
else:
if num_beams == 1:
log_probs = [val[context_length:, :] for val in log_probs] if log_probs is not None else None
if self.detokenize:
try:
output_checked = self._ids_check(output, self.tokenizer)
output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
except Exception as e:
error_info = "Meet errors when trying to decode the tokens. "\
"Please handle it by yourself."
logging.error(error_info)
logging.error(e)
yield output, log_probs[0] if len(log_probs) == 1 else log_probs
output = output[0] if len(output) == 1 else output
if not self.return_output_log_probs:
res = output
else:
if self.num_beams == 1:
log_probs = [val[context_lengths[i]:, :] for i, val in enumerate(log_probs)] \
if log_probs is not None else None
res = output, log_probs[0] if len(log_probs) == 1 else log_probs
return res
def _truncate_in_multi_batch(self, output):
if len(output) > 1:
for idx, batch in enumerate(output):
trunc_index = torch.nonzero(batch == self.tokenizer.eos_token_id)
if min(trunc_index.shape):
output[idx][trunc_index.min():] = self.tokenizer.eos_token_id
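
A standalone sketch of the truncation rule implemented by _truncate_in_multi_batch above, with toy tensors in place of real model output and an example eos_token_id:

import torch

def truncate_after_eos(output, eos_token_id):
    """For each sequence, overwrite everything from the first eos onward with eos."""
    if len(output) > 1:  # mirrors the batch-size > 1 condition above
        for idx, batch in enumerate(output):
            trunc_index = torch.nonzero(batch == eos_token_id)
            if min(trunc_index.shape):  # at least one eos found
                output[idx][trunc_index.min():] = eos_token_id
    return output

seqs = [torch.tensor([5, 6, 2, 9, 9]), torch.tensor([7, 8, 1, 2, 4])]
print(truncate_after_eos(seqs, eos_token_id=2))
# [tensor([5, 6, 2, 2, 2]), tensor([7, 8, 1, 2, 2])]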
def _yield(self, token_stream):
output, context_lengths, log_probs = None, None, None
for output, context_lengths, log_probs in token_stream:
if self.stream:
res = self._post_processing(output, context_lengths, log_probs)
yield res
if not self.stream:
yield self._post_processing(output, context_lengths, log_probs)
def _token_generator(self, token_stream):
token_stream = self._yield(token_stream)
if not self.stream:
full_output = None
for tmp in token_stream:
full_output = tmp
return full_output
else:
return token_stream
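
As a standalone illustration of the input normalization that the new _tokenize performs for the str | list | LongTensor contract, a minimal sketch; _encode_stub is a stand-in for the real tokenizer, and the empty-input and broadcast handling is omitted:

import torch

def _encode_stub(text):
    # Stand-in for tokenizer.encode(): fake token ids, for demonstration only.
    return [ord(ch) % 100 for ch in text]

def normalize_input_ids(input_ids):
    """Normalize str | list | LongTensor into a list of token-id lists."""
    if isinstance(input_ids, str):
        return [_encode_stub(input_ids)]
    if torch.is_tensor(input_ids):
        if input_ids.dim() == 1:
            return input_ids.unsqueeze(0).tolist()
        if input_ids.dim() == 2:
            return input_ids.tolist()
    if isinstance(input_ids, (tuple, list)) and len(input_ids):
        if isinstance(input_ids[0], (tuple, list)):
            return list(input_ids)
        if isinstance(input_ids[0], int):
            return [list(input_ids)]
        if isinstance(input_ids[0], str):
            return [_encode_stub(val) for val in input_ids]
    raise TypeError("Please check that input_ids is of a supported type.")

print(normalize_input_ids("Hello!"))                  # single text -> batch of 1
print(normalize_input_ids([[0, 13, 5], [0, 21]]))     # already-batched token lists
print(normalize_input_ids(torch.tensor([0, 13, 5])))  # 1-D tensor -> batch of 1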

View File

@ -70,30 +70,26 @@ class ForwardStep:
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
self.pipelining_batch_x_seqlen = args.inference_batch_times_seqlen_threshold
self.micro_batch_size = args.micro_batch_size
def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params,
micro_batch_size)
return _no_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params)
return _with_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params,
self.micro_batch_size)
else:
return _no_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params)
def _get_recv_buffer_dtype(args):
@ -157,12 +153,12 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s
dtype=torch.float32, device=torch.cuda.current_device())
for micro_batch_index in range(num_micro_batches):
# Slice among the batch dimenion.
# Slice among the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
tokens2use = tokens[start: end, ...]
position_ids2use = position_ids[start: end, ...]
output = _forward_step_helper(model,
tokens2use,
@ -176,7 +172,7 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s
# Copy logits.
if parallel_state.is_pipeline_last_stage():
logits[start:end, ...] = output
logits[start: end, ...] = output
# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
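
To make the micro-batching above concrete, a self-contained sketch of slicing a (batch, seq) tokens tensor along the batch dimension and stitching the per-slice logits back together; the toy forward function stands in for the model, and all pipeline communication and inference_params handling is omitted:

import math
import torch

def run_in_micro_batches(tokens, micro_batch_size, forward_fn, vocab_size):
    """Run forward_fn on batch slices of `tokens` and collect the logits."""
    batch_size, seq_length = tokens.size()
    num_micro_batches = math.ceil(batch_size / micro_batch_size)
    logits = torch.empty(batch_size, seq_length, vocab_size)
    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        logits[start:end, ...] = forward_fn(tokens[start:end, ...])
    return logits

def toy_forward(tokens):
    # Random logits with the expected (micro_batch, seq, vocab) shape.
    return torch.randn(tokens.size(0), tokens.size(1), 8)

out = run_in_micro_batches(torch.zeros(5, 4, dtype=torch.long), 2, toy_forward, 8)
print(out.shape)  # torch.Size([5, 4, 8])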

View File

@ -22,10 +22,15 @@ def beam_search(model, tokens, **kwargs):
length_penalty = kwargs.pop("length_penalty", 1.0)
args = get_args()
if args.micro_batch_size > 1:
raise NotImplementedError("The number of input prompts should not be greater than 1 "
"(i.e. micro_batch_size must be 1) in beam search mode.")
# ==========================
# Pad tokens
# ==========================
final_sequence_length, prompt_length, tokens = _pad_tokens(args, tokens)
final_sequence_length = args.max_length_ori
prompt_length, context_lengths, tokens = _pad_tokens(args, tokens, beam_size, num_return_gen)
# ==========================
# Forward step
@ -46,11 +51,11 @@ def beam_search(model, tokens, **kwargs):
# ==========================
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
micro_batch_size, seq_length = tokens.size()
batch_size, seq_length = tokens.size()
attention_mask = torch.tril(torch.ones(
(micro_batch_size, seq_length, seq_length), device=tokens.device)).view(
micro_batch_size, 1, seq_length, seq_length)
(args.micro_batch_size, seq_length, seq_length), device=tokens.device)).view(
args.micro_batch_size, 1, seq_length, seq_length)
attention_mask = (attention_mask < 0.5)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=tokens.device)
@ -66,24 +71,25 @@ def beam_search(model, tokens, **kwargs):
num_return_gen=num_return_gen,
position_ids=position_ids,
prompt_length=prompt_length,
context_lengths=context_lengths,
scores=scores,
stop_token=stop_token,
tokens=tokens)
output_scores, output_tokens = _beam_search_post_process(beam_hyp=beam_hyp,
beam_size=beam_size,
context_length=context_length,
done=done,
num_return_gen=num_return_gen,
output_scores=output_scores,
output_tokens=output_tokens,
context_length=context_length,
prompt_length=prompt_length,
scores=scores,
scores_size_tensor=scores_size_tensor,
tokens=tokens,
tokens_size_tensor=tokens_size_tensor)
yield output_tokens, None, torch.exp(output_scores)
yield output_tokens, context_lengths, torch.exp(output_scores)
def forward_loop(args, **kwargs):
@ -96,6 +102,7 @@ def forward_loop(args, **kwargs):
num_return_gen = kwargs.pop("num_return_gen")
position_ids = kwargs.pop("position_ids")
prompt_length = kwargs.pop("prompt_length")
context_lengths = kwargs.pop("context_lengths")
scores = kwargs.pop("scores")
stop_token = kwargs.pop("stop_token")
tokens = kwargs.pop("tokens")
@ -134,8 +141,8 @@ def forward_loop(args, **kwargs):
tokens = broadcast_from_last_pipeline_stage(tokens.size(), torch.int64, tokens)
yield tokens[:num_return_gen], None, torch.exp(scores[:num_return_gen])
yield tokens[:num_return_gen], context_lengths, torch.exp(scores[:num_return_gen])
output_info = (context_length, done, scores, tokens)
return output_info
@ -293,17 +300,9 @@ def _beam_candidates_at_beginning(args, beam_size, new_scores):
return indices, sorted_scores
def _pad_tokens(args, tokens):
def _pad_tokens(args, tokens, beam_size, num_return_gen):
tokens, lengths = pad_batch(tokens, args)
tokens = get_accelerator().LongTensor(tokens)
prompt_length = min(lengths)
if args.text_generation_config['max_new_tokens'] > 0:
final_sequence_length = prompt_length + args.text_generation_config['max_new_tokens']
else:
final_sequence_length = args.text_generation_config['max_length']
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if prompt_length >= final_sequence_length:
raise ValueError("The length of your input text exceeds the maximum. "
"Please increase the value of 'max_length'.")
return final_sequence_length, prompt_length, tokens
prompt_length = lengths.min().item()
lengths = lengths.repeat(min(beam_size, num_return_gen)).cpu().numpy().tolist()
return prompt_length, lengths, tokens
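
A small standalone sketch of the new return contract of _pad_tokens for the beam-search path: a scalar prompt length plus one context length per returned hypothesis, repeated min(beam_size, num_return_gen) times; pad_batch itself is replaced by a fixed toy lengths tensor here:

import torch

def toy_pad_tokens_lengths(lengths, beam_size, num_return_gen):
    """Mirror of the repeated-lengths bookkeeping, without padding or args."""
    prompt_length = lengths.min().item()
    context_lengths = lengths.repeat(min(beam_size, num_return_gen)).tolist()
    return prompt_length, context_lengths

# Beam search allows a single prompt, so `lengths` has one entry.
print(toy_pad_tokens_lengths(torch.tensor([7]), beam_size=4, num_return_gen=2))
# (7, [7, 7])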

View File

@ -28,7 +28,7 @@ def get_batch(context_tokens):
tokenizer = get_tokenizer()
# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().to(get_accelerator().device_name())
tokens = context_tokens.contiguous().to(get_accelerator().device_name())
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
@ -41,29 +41,45 @@ def get_batch(context_tokens):
def pad_batch(batch, args):
max_context_length = get_accelerator().LongTensor([max(len(val) for val in batch)])
torch.distributed.all_reduce(max_context_length)
max_context_length = torch.div(max_context_length, torch.distributed.get_world_size(), rounding_mode="floor")
tokenizer = get_tokenizer()
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
context_lengths = [len(val) for val in batch]
if args.text_generation_config['max_new_tokens'] > 0:
max_length = max(context_lengths) + args.text_generation_config['max_new_tokens']
max_length = max_context_length[0].item() + args.text_generation_config['max_new_tokens']
else:
max_length = args.text_generation_config['max_length']
# set fused_operator_contiguous_num = 32
max_length = math.ceil(max_length / 32) * 32
max_length_padded = math.ceil(max_length / 32) * 32
for i, tokens in enumerate(batch):
if context_lengths[i] < max_length:
tokens.extend([pad_id] * (max_length - context_lengths[i]))
if context_lengths[i] < max_length_padded:
tokens.extend([pad_id] * (max_length_padded - context_lengths[i]))
return batch, context_lengths
context_tokens_tensor = get_accelerator().LongTensor(batch)
context_length_tensor = get_accelerator().LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor, args.master_rank)
torch.distributed.broadcast(context_tokens_tensor, args.master_rank)
args.seq_length = context_tokens_tensor.shape[1]
args.max_position_embeddings = args.seq_length
args.max_length_ori = max_length
return context_tokens_tensor, context_length_tensor
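
For illustration, the padding rule used by pad_batch in standalone form: the target length is the (all-reduced) max context length plus max_new_tokens, rounded up to a multiple of 32, the fused_operator_contiguous_num noted in the comment above; the distributed all-reduce, broadcast, and args bookkeeping are omitted, and the names here are assumptions:

import math

def toy_pad_batch(batch, max_new_tokens, pad_id=0, align=32):
    """Pad token lists to a shared length that is a multiple of `align`."""
    context_lengths = [len(val) for val in batch]
    max_length = max(context_lengths) + max_new_tokens
    max_length_padded = math.ceil(max_length / align) * align
    padded = [tokens + [pad_id] * (max_length_padded - len(tokens)) for tokens in batch]
    return padded, context_lengths

padded, lengths = toy_pad_batch([[5, 6, 7], [1, 2]], max_new_tokens=10)
print(lengths, len(padded[0]), len(padded[1]))  # [3, 2] 32 32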
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
"""
This function has been mostly taken from huggingface conversational ai code
This function has been mostly taken from huggingface conversational ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer
-learning-2d818ac26313
"""
if top_k > 0:
@ -98,16 +114,7 @@ def greedy_search_or_sampling(model, context_tokens, model_latencies=None, singl
model_latencies = [] if model_latencies is None else model_latencies
single_token_latency = [] if single_token_latency is None else single_token_latency
context_tokens, context_lengths = pad_batch(context_tokens, args)
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
context_length_tensor = get_accelerator().LongTensor(context_lengths)
torch.distributed.broadcast(context_length_tensor,
parallel_state.get_tensor_model_parallel_src_rank(),
group=parallel_state.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
parallel_state.get_tensor_model_parallel_src_rank(),
group=parallel_state.get_tensor_model_parallel_group())
context_tokens_tensor, context_length_tensor = pad_batch(context_tokens, args)
context_length = context_length_tensor.min().item()
@ -118,13 +125,17 @@ def greedy_search_or_sampling(model, context_tokens, model_latencies=None, singl
model_latencies=model_latencies
)
count = 0
yield from _post_process(batch_token_iterator, context_length, count, single_token_latency)
yield from _post_process(
batch_token_iterator,
context_length,
context_length_tensor,
single_token_latency
)
def _post_process(batch_token_iterator, context_length, count, single_token_latency):
def _post_process(batch_token_iterator, context_length, context_lengths, single_token_latency):
t0 = time.time()
count = 0
for tokens, lengths, log_probs in batch_token_iterator:
if count > 1:
get_accelerator().synchronize()
@ -135,13 +146,12 @@ def _post_process(batch_token_iterator, context_length, count, single_token_late
count += 1
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths, log_probs
yield tokens[:, :context_length], context_lengths.cpu().numpy().tolist(), log_probs
else:
yield None, None, None
def switch(val1, val2, boolean):
boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
@ -165,9 +175,10 @@ def forward_step(model, tokens, **kwargs):
t0 = time.time()
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
args.micro_batch_size = tokens.shape[0]
config = get_model_config(model)
tensor_shapes = [args.seq_length, args.micro_batch_size, args.hidden_size]
input_tensor = recv_forward(tensor_shapes, config)
_unwrap_and_set_input_tensor(args, input_tensor, model)
@ -241,10 +252,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths, type_ids=None,
counter = 0
layer_past = None
batch_size = tokens.size(0)
max_length = tokens.size(1)
max_length = args.max_length_ori
context_length = context_lengths.min().item()
is_done = torch.zeros([batch_size]).byte().to(get_accelerator().device_name())
lengths = torch.ones([batch_size]).long().to(get_accelerator().device_name()) * max_length
while context_length < max_length:
if args.text_generation_config['recompute']:
@ -277,7 +287,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, type_ids=None,
done = _is_done(is_done, prev, started, tokenizer)
yield tokens, lengths, output_log_probs
yield tokens, max_length, output_log_probs
context_length += 1
counter += 1

View File

@ -33,5 +33,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
--seed 42 \
--position-embedding-type alibi \

View File

@ -62,7 +62,7 @@ def task_do_sample(args, model, tokenizer=None, system_template="", dialog_templ
t = time.time()
output = model.generate(
instruction,
[instruction, instruction],
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,