mirror of https://gitee.com/ascend/ModelLink.git
synced 2024-12-05 05:17:40 +08:00

commit 63ee9d5ca3

OWNERS
@@ -10,3 +10,4 @@ reviewers:
- zhangshengdong
- kingsleyandher
- guo-xinjie-1
- matrixssy
@@ -278,10 +278,14 @@ class MegatronModuleForCausalLMABC(torch.nn.Module, abc.ABC):

Parameters:
----------
input_ids(str | torch.Tensor):
input_ids(str | list | LongTensor):
The text entered by the user, e.g. 'Hello!'
Or
The text entered by the user, already encoded by the tokenizer, e.g. [0, 13, 5, ...]
Or
A list containing multiple texts or token sequences,
e.g. [['Hello!'], ["How are you?"]] | [[0, 13, 5, ...], [0, 21, ...]].
Note that multiple texts or token sequences are forbidden in beam-search mode.
do_sample (`bool`, *optional*, defaults to `False`):
Whether to use sampling; use greedy decoding otherwise.
top_k (`int`, *optional*, defaults to 0):
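As an editorial aside, here is a minimal, self-contained sketch of the three input_ids forms the docstring above accepts; encode() is a toy stand-in for the real tokenizer and is only illustrative.

```python
# A minimal sketch of the accepted input_ids forms (encode() below is a toy
# stand-in for the real tokenizer, used only for illustration).
def encode(text):
    return [ord(c) % 100 for c in text]

single_text = "Hello!"                       # str
single_tokens = encode(single_text)          # already-encoded token ids
batched_texts = ["Hello!", "How are you?"]   # multiple prompts (forbidden in beam-search mode)
batched_tokens = [encode(t) for t in batched_texts]

print(single_tokens)
print(batched_tokens)
```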
@@ -420,37 +424,11 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):

return engine

@staticmethod
def _broadcast_tokens(context_tokens, context_length, master_rank):
if dist.get_world_size() > 1:
if dist.get_rank() == master_rank:
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
dist.broadcast(context_tokens_tensor, master_rank)
else:
context_tokens_tensor = torch.empty(context_length,
dtype=torch.int64,
device=torch.device(get_accelerator().device_name()))
dist.broadcast(context_tokens_tensor, master_rank)
else:
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)

return context_tokens_tensor

@staticmethod
def _check_output(output, stream):
if not stream:
full_output = None
for tmp in output:
full_output = tmp
return full_output
else:
return output

@staticmethod
def _ids_check(ids, tokenizer):
checked_ids = []
for per_ids in ids:
if torch.max(per_ids) >= len(tokenizer):
if per_ids == torch.Size([]) and torch.max(per_ids) >= len(tokenizer):
warning_info = "The output ids exceeds the tokenizer length, "\
"the clamp operation is enforced, please check!!"
logging.warning(warning_info)
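The _check_output helper above either drains a token generator (non-streaming) or hands it back untouched (streaming). A standalone sketch of that pattern, with a toy generator in place of the real token stream:

```python
# Standalone sketch of the drain-or-pass-through pattern in _check_output above
# (toy generator; no model involved).
def toy_token_stream():
    for step in ["Hel", "Hello", "Hello!"]:
        yield step

def check_output(output, stream):
    if not stream:
        full_output = None
        for tmp in output:      # keep only the final item of the stream
            full_output = tmp
        return full_output
    return output               # streaming callers iterate it themselves

print(check_output(toy_token_stream(), stream=False))        # -> "Hello!"
print(list(check_output(toy_token_stream(), stream=True)))   # -> ["Hel", "Hello", "Hello!"]
```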
@@ -537,14 +515,9 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
# so you don't need to pass the prompt on
# each process.
# =======================================
context_length, context_tokens, master_rank = self._tokenize(input_ids)

# =======================================
# For parallel we need to send context tokens
# to other process
# =======================================
context_tokens_tensor = self._broadcast_tokens(context_tokens, context_length, master_rank).unsqueeze(0)
context_tokens = context_tokens_tensor.cpu().numpy().tolist()
context_tokens, master_rank = self._tokenize(input_ids)
args.master_rank = master_rank
args.micro_batch_size = len(context_tokens)

# =======================================
# Get the streaming tokens generator
@@ -568,12 +541,7 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
# Post processing in order to get final
# output texts/tokens
# =======================================
output = self._post_processing(token_stream,
context_length,
self.include_input,
self.detokenize,
self.num_beams)
return self._check_output(output, self.stream)
return self._token_generator(token_stream)

def _init_tokenizer(self, args):
if self.tokenizer_new is None:
@@ -596,57 +564,94 @@ class MegatronModuleForCausalLM(MegatronModuleForCausalLMABC):
else:
raise ValueError("Your tokenizer doesn't include eos_token.")

if self.tokenizer.pad_token_id is None:
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

def _tokenize(self, input_ids):
context_tokens = [[]]
broadcast_rank = torch.zeros(dist.get_world_size(),
dtype=torch.int64,
device=torch.device(get_accelerator().device_name()))

if input_ids:
if input_ids is not None and len(input_ids) > 0:
if isinstance(input_ids, str):
context_tokens = self.tokenizer.encode(input_ids)
context_tokens = [self.tokenizer.encode(input_ids)]
elif torch.is_tensor(input_ids):
if len(input_ids.shape) == 1:
context_tokens = input_ids.unsqueeze(0).numpy().tolist()
elif len(input_ids.shape) == 2:
context_tokens = input_ids.numpy().tolist()
elif isinstance(input_ids, (tuple, list)):
if len(input_ids) and isinstance(input_ids[0], (tuple, list)):
context_tokens = input_ids
elif len(input_ids) and isinstance(input_ids[0], int):
context_tokens = [input_ids]
elif len(input_ids) and isinstance(input_ids[0], str):
context_tokens = [self.tokenizer.encode(val) for val in input_ids]
else:
context_tokens = input_ids
raise TypeError("Please check input_ids in correct type.")

context_length = len(context_tokens)
counts = 1
broadcast_rank[dist.get_rank()] = 1
else:
context_tokens = [self.tokenizer.encode("EMPTY TEXT")]
context_length = 0
counts = 0

input_info = [counts, context_length]
input_info_tensor = get_accelerator().LongTensor(input_info)
dist.all_reduce(input_info_tensor)
dist.all_reduce(broadcast_rank)
counts = input_info_tensor[0].item()
if counts == 0:
raise ValueError("Please pass prompt on at least one process.")
context_length = input_info_tensor[1].item() // counts
master_rank = torch.nonzero(broadcast_rank)[0, 0]
return context_length, context_tokens, master_rank

def _post_processing(self, token_stream, context_length, include_input, detokenize, num_beams):
for output, _, log_probs in token_stream:
if not include_input:
output = [val[context_length:] for val in output]
return context_tokens, master_rank

if detokenize:
try:
output_checked = self._ids_check(output, self.tokenizer)
output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
except Exception as e:
error_info = "Meet errors when trying to decode the tokens. "\
"Please handle it by yourself."
logging.error(error_info)
logging.error(e)
def _post_processing(self, output, context_lengths, log_probs):
if not self.include_input:
output = [val[context_lengths[i]:] for i, val in enumerate(output)]

output = output[0] if len(output) == 1 else output
# When batch size > 1, you need to truncate the tokens after eos_token_id
self._truncate_in_multi_batch(output)

if not self.return_output_log_probs:
yield output
else:
if num_beams == 1:
log_probs = [val[context_length:, :] for val in log_probs] if log_probs is not None else None
if self.detokenize:
try:
output_checked = self._ids_check(output, self.tokenizer)
output = self.tokenizer.batch_decode(output_checked, skip_special_tokens=True)
except Exception as e:
error_info = "Meet errors when trying to decode the tokens. "\
"Please handle it by yourself."
logging.error(error_info)
logging.error(e)

yield output, log_probs[0] if len(log_probs) == 1 else log_probs
output = output[0] if len(output) == 1 else output

if not self.return_output_log_probs:
res = output
else:
if self.num_beams == 1:
log_probs = [val[context_lengths[i]:, :] for i, val in enumerate(log_probs)] \
if log_probs is not None else None

res = output, log_probs[0] if len(log_probs) == 1 else log_probs

return res

def _truncate_in_multi_batch(self, output):
if len(output) > 1:
for idx, batch in enumerate(output):
trunc_index = torch.nonzero(batch == self.tokenizer.eos_token_id)

if min(trunc_index.shape):
output[idx][trunc_index.min():] = self.tokenizer.eos_token_id

def _yield(self, token_stream):
output, context_lengths, log_probs = None, None, None
for output, context_lengths, log_probs in token_stream:
if self.stream:
res = self._post_processing(output, context_lengths, log_probs)
yield res

if not self.stream:
yield self._post_processing(output, context_lengths, log_probs)

def _token_generator(self, token_stream):
token_stream = self._yield(token_stream)
if not self.stream:
full_output = None
for tmp in token_stream:
full_output = tmp
return full_output
else:
return token_stream
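_truncate_in_multi_batch above masks everything after the first eos token in each sequence of a multi-sample batch. A standalone sketch of that step with toy tensors and an illustrative eos id:

```python
# Standalone sketch of the eos truncation done by _truncate_in_multi_batch above:
# everything after the first eos_token_id in each sequence is overwritten with
# eos_token_id (toy tensors; eos id chosen for illustration).
import torch

eos_token_id = 2
output = [torch.tensor([5, 7, eos_token_id, 9, 11]),
          torch.tensor([4, 6, 8, eos_token_id, 3])]

if len(output) > 1:
    for idx, batch in enumerate(output):
        trunc_index = torch.nonzero(batch == eos_token_id)
        if min(trunc_index.shape):                      # at least one eos found
            output[idx][trunc_index.min():] = eos_token_id

print(output)   # tails after the first eos are masked with eos_token_id
```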
@@ -70,30 +70,26 @@ class ForwardStep:
self.pipeline_size_larger_than_one = (
args.pipeline_model_parallel_size > 1)
# Threshold of pipelining.
self.pipelining_batch_x_seqlen = \
args.inference_batch_times_seqlen_threshold
self.pipelining_batch_x_seqlen = args.inference_batch_times_seqlen_threshold
self.micro_batch_size = args.micro_batch_size

def __call__(self, tokens, position_ids, attention_mask):
"""Invocation of the forward methods. Note that self.inference_params
is being modified by the forward step."""
# Pipelining case.
if self.pipeline_size_larger_than_one:
current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
micro_batch_size = \
max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
return _with_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params,
micro_batch_size)

return _no_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params)
return _with_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params,
self.micro_batch_size)
else:
return _no_pipelining_forward_step(self.model,
(tokens,
position_ids,
attention_mask),
self.inference_params)


def _get_recv_buffer_dtype(args):
@@ -157,12 +153,12 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s
dtype=torch.float32, device=torch.cuda.current_device())

for micro_batch_index in range(num_micro_batches):
# Slice among the batch dimenion.
# Slice among the batch dimension.
start = micro_batch_index * micro_batch_size
end = min(start + micro_batch_size, batch_size)
this_micro_batch_size = end - start
tokens2use = tokens[start:end, ...]
position_ids2use = position_ids[start:end, ...]
tokens2use = tokens[start: end, ...]
position_ids2use = position_ids[start: end, ...]

output = _forward_step_helper(model,
tokens2use,
@@ -176,7 +172,7 @@ def _with_pipelining_forward_step(model, inputs, inference_params, micro_batch_s

# Copy logits.
if parallel_state.is_pipeline_last_stage():
logits[start:end, ...] = output
logits[start: end, ...] = output

# Once we are done with all the micro-batches, we can
# adjust the sequence length offset.
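The two hunks above slice the batch into micro-batches and copy each slice's output into a preallocated logits buffer. A standalone sketch of that slicing pattern, with an identity-style stand-in for the real forward step:

```python
# Standalone sketch of the micro-batch slicing pattern above (toy shapes;
# a one-hot "forward" stands in for the real model call).
import torch

batch_size, seq_length, vocab = 5, 4, 7
micro_batch_size = 2
tokens = torch.randint(0, vocab, (batch_size, seq_length))
logits = torch.empty(batch_size, seq_length, vocab, dtype=torch.float32)

num_micro_batches = (batch_size + micro_batch_size - 1) // micro_batch_size
for micro_batch_index in range(num_micro_batches):
    # Slice along the batch dimension.
    start = micro_batch_index * micro_batch_size
    end = min(start + micro_batch_size, batch_size)
    tokens2use = tokens[start:end, ...]
    output = torch.nn.functional.one_hot(tokens2use, vocab).float()  # stand-in forward step
    logits[start:end, ...] = output

print(logits.shape)  # torch.Size([5, 4, 7])
```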
@@ -22,10 +22,15 @@ def beam_search(model, tokens, **kwargs):
length_penalty = kwargs.pop("length_penalty", 1.0)
args = get_args()

if args.micro_batch_size > 1:
raise NotImplementedError("The number of input prompts should not be greater than 1 "
"(i.e. micro_batch_size must be 1) in beam search mode.")

# ==========================
# Pad tokens
# ==========================
final_sequence_length, prompt_length, tokens = _pad_tokens(args, tokens)
final_sequence_length = args.max_length_ori
prompt_length, context_lengths, tokens = _pad_tokens(args, tokens, beam_size, num_return_gen)

# ==========================
# Forward step
@@ -46,11 +51,11 @@ def beam_search(model, tokens, **kwargs):
# ==========================
with torch.no_grad():
tokens = tokens.repeat(beam_size, 1)
micro_batch_size, seq_length = tokens.size()
batch_size, seq_length = tokens.size()

attention_mask = torch.tril(torch.ones(
(micro_batch_size, seq_length, seq_length), device=tokens.device)).view(
micro_batch_size, 1, seq_length, seq_length)
(args.micro_batch_size, seq_length, seq_length), device=tokens.device)).view(
args.micro_batch_size, 1, seq_length, seq_length)
attention_mask = (attention_mask < 0.5)
position_ids = torch.arange(seq_length, dtype=torch.long,
device=tokens.device)
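A standalone sketch of the beam expansion and causal-mask construction shown above, using a toy prompt and beam size; the mask batch dimension is 1 here because beam search in this code requires micro_batch_size == 1:

```python
# Standalone sketch of the beam expansion and causal-mask construction above
# (toy prompt and beam_size; True marks positions that must be masked out).
import torch

beam_size, seq_length = 3, 5
tokens = torch.arange(seq_length).unsqueeze(0)        # one prompt, shape (1, seq_length)
tokens = tokens.repeat(beam_size, 1)                  # shape (beam_size, seq_length)

attention_mask = torch.tril(torch.ones(
    (1, seq_length, seq_length), device=tokens.device)).view(
    1, 1, seq_length, seq_length)
attention_mask = (attention_mask < 0.5)               # True above the diagonal -> masked

print(tokens.shape, attention_mask.shape)
```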
@@ -66,24 +71,25 @@ def beam_search(model, tokens, **kwargs):
num_return_gen=num_return_gen,
position_ids=position_ids,
prompt_length=prompt_length,
context_lengths=context_lengths,
scores=scores,
stop_token=stop_token,
tokens=tokens)

output_scores, output_tokens = _beam_search_post_process(beam_hyp=beam_hyp,
beam_size=beam_size,
context_length=context_length,
done=done,
num_return_gen=num_return_gen,
output_scores=output_scores,
output_tokens=output_tokens,
context_length=context_length,
prompt_length=prompt_length,
scores=scores,
scores_size_tensor=scores_size_tensor,
tokens=tokens,
tokens_size_tensor=tokens_size_tensor)

yield output_tokens, None, torch.exp(output_scores)
yield output_tokens, context_lengths, torch.exp(output_scores)


def forward_loop(args, **kwargs):
@@ -96,6 +102,7 @@ def forward_loop(args, **kwargs):
num_return_gen = kwargs.pop("num_return_gen")
position_ids = kwargs.pop("position_ids")
prompt_length = kwargs.pop("prompt_length")
context_lengths = kwargs.pop("context_lengths")
scores = kwargs.pop("scores")
stop_token = kwargs.pop("stop_token")
tokens = kwargs.pop("tokens")
@@ -134,8 +141,8 @@ def forward_loop(args, **kwargs):

tokens = broadcast_from_last_pipeline_stage(tokens.size(), torch.int64, tokens)

yield tokens[:num_return_gen], None, torch.exp(scores[:num_return_gen])

yield tokens[:num_return_gen], context_lengths, torch.exp(scores[:num_return_gen])

output_info = (context_length, done, scores, tokens)
return output_info
@@ -293,17 +300,9 @@ def _beam_candidates_at_beginning(args, beam_size, new_scores):
return indices, sorted_scores


def _pad_tokens(args, tokens):
def _pad_tokens(args, tokens, beam_size, num_return_gen):
tokens, lengths = pad_batch(tokens, args)
tokens = get_accelerator().LongTensor(tokens)
prompt_length = min(lengths)
if args.text_generation_config['max_new_tokens'] > 0:
final_sequence_length = prompt_length + args.text_generation_config['max_new_tokens']
else:
final_sequence_length = args.text_generation_config['max_length']
final_sequence_length = min(final_sequence_length, args.max_position_embeddings)
# If the context is too big, this happens
if prompt_length >= final_sequence_length:
raise ValueError("The length of your input text exceeds the maximum. "
"Please increase the value of 'max_length'.")
return final_sequence_length, prompt_length, tokens
prompt_length = lengths.min().item()
lengths = lengths.repeat(min(beam_size, num_return_gen)).cpu().numpy().tolist()

return prompt_length, lengths, tokens
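The new _pad_tokens returns the shortest prompt length plus per-beam context lengths built by repeating the length tensor. A toy sketch of that bookkeeping (beam_size and num_return_gen are illustrative values):

```python
# Standalone sketch of the length bookkeeping in the new _pad_tokens above
# (toy lengths; beam_size and num_return_gen chosen for illustration).
import torch

lengths = torch.tensor([7])                 # length of the single padded prompt
beam_size, num_return_gen = 4, 2

prompt_length = lengths.min().item()        # -> 7
context_lengths = lengths.repeat(min(beam_size, num_return_gen)).cpu().numpy().tolist()

print(prompt_length, context_lengths)       # 7 [7, 7]
```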
@@ -28,7 +28,7 @@ def get_batch(context_tokens):
tokenizer = get_tokenizer()

# Move to GPU.
tokens = context_tokens.view(args.micro_batch_size, -1).contiguous().to(get_accelerator().device_name())
tokens = context_tokens.contiguous().to(get_accelerator().device_name())
# Get the attention mask and position ids.
attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
tokens,
@@ -41,29 +41,45 @@ def get_batch(context_tokens):


def pad_batch(batch, args):
max_context_length = get_accelerator().LongTensor([max(len(val) for val in batch)])
torch.distributed.all_reduce(max_context_length)
max_context_length = torch.div(max_context_length, torch.distributed.get_world_size(), rounding_mode="floor")

tokenizer = get_tokenizer()

pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id else tokenizer.eos_token_id
context_lengths = [len(val) for val in batch]

if args.text_generation_config['max_new_tokens'] > 0:
max_length = max(context_lengths) + args.text_generation_config['max_new_tokens']
max_length = max_context_length[0].item() + args.text_generation_config['max_new_tokens']
else:
max_length = args.text_generation_config['max_length']

# set fused_operator_contiguous_num = 32
max_length = math.ceil(max_length / 32) * 32
max_length_padded = math.ceil(max_length / 32) * 32

for i, tokens in enumerate(batch):
if context_lengths[i] < max_length:
tokens.extend([pad_id] * (max_length - context_lengths[i]))
if context_lengths[i] < max_length_padded:
tokens.extend([pad_id] * (max_length_padded - context_lengths[i]))

return batch, context_lengths
context_tokens_tensor = get_accelerator().LongTensor(batch)
context_length_tensor = get_accelerator().LongTensor(context_lengths)

torch.distributed.broadcast(context_length_tensor, args.master_rank)
torch.distributed.broadcast(context_tokens_tensor, args.master_rank)

args.seq_length = context_tokens_tensor.shape[1]
args.max_position_embeddings = args.seq_length
args.max_length_ori = max_length

return context_tokens_tensor, context_length_tensor
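pad_batch above rounds the target length up to a multiple of 32 (the fused_operator_contiguous_num comment) and right-pads every prompt with pad_id. A standalone sketch of that rule with toy prompts and illustrative pad_id / max_new_tokens values:

```python
# Standalone sketch of the pad-to-a-multiple-of-32 rule in pad_batch above
# (toy prompts; pad_id and max_new_tokens are illustrative values).
import math

pad_id = 0
max_new_tokens = 16
batch = [[11, 12, 13], [21, 22, 23, 24, 25]]

context_lengths = [len(val) for val in batch]
max_length = max(context_lengths) + max_new_tokens         # 5 + 16 = 21
max_length_padded = math.ceil(max_length / 32) * 32         # -> 32

for i, tokens in enumerate(batch):
    if context_lengths[i] < max_length_padded:
        tokens.extend([pad_id] * (max_length_padded - context_lengths[i]))

print(context_lengths, [len(t) for t in batch])             # [3, 5] [32, 32]
```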
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
"""
This function has been mostly taken from huggingface conversational ai code
This function has been mostly taken from huggingface conversational ai code at
https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer
-learning-2d818ac26313
"""

if top_k > 0:
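top_k_logits implements the usual top-k / top-p (nucleus) filtering: logits outside the top k, or past the cumulative-probability cutoff, are set to filter_value so they cannot be sampled. The sketch below mirrors the widely used HuggingFace recipe rather than quoting this file's exact body:

```python
# Self-contained sketch of top-k / top-p filtering in the spirit of top_k_logits
# above (toy 1-D logits; mirrors the common HuggingFace recipe, not this repo's code).
import torch

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    if top_k > 0:
        # Remove everything below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability exceeds top_p, keeping the first one above it.
        sorted_indices_to_remove = cumulative_probs > top_p
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

print(filter_logits(torch.tensor([1.0, 3.0, 0.5, 2.0]), top_k=2))
```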
@@ -98,16 +114,7 @@ def greedy_search_or_sampling(model, context_tokens, model_latencies=None, singl
model_latencies = [] if model_latencies is None else model_latencies
single_token_latency = [] if single_token_latency is None else single_token_latency

context_tokens, context_lengths = pad_batch(context_tokens, args)
context_tokens_tensor = get_accelerator().LongTensor(context_tokens)
context_length_tensor = get_accelerator().LongTensor(context_lengths)

torch.distributed.broadcast(context_length_tensor,
parallel_state.get_tensor_model_parallel_src_rank(),
group=parallel_state.get_tensor_model_parallel_group())
torch.distributed.broadcast(context_tokens_tensor,
parallel_state.get_tensor_model_parallel_src_rank(),
group=parallel_state.get_tensor_model_parallel_group())
context_tokens_tensor, context_length_tensor = pad_batch(context_tokens, args)

context_length = context_length_tensor.min().item()
@@ -118,13 +125,17 @@ def greedy_search_or_sampling(model, context_tokens, model_latencies=None, singl
model_latencies=model_latencies
)

count = 0

yield from _post_process(batch_token_iterator, context_length, count, single_token_latency)
yield from _post_process(
batch_token_iterator,
context_length,
context_length_tensor,
single_token_latency
)


def _post_process(batch_token_iterator, context_length, count, single_token_latency):
def _post_process(batch_token_iterator, context_length, context_lengths, single_token_latency):
t0 = time.time()
count = 0
for tokens, lengths, log_probs in batch_token_iterator:
if count > 1:
get_accelerator().synchronize()
@@ -135,13 +146,12 @@ def _post_process(batch_token_iterator, context_length, count, single_token_late
count += 1
context_length += 1
if tokens is not None:
yield tokens[:, :context_length], lengths, log_probs
yield tokens[:, :context_length], context_lengths.cpu().numpy().tolist(), log_probs
else:
yield None, None, None


def switch(val1, val2, boolean):

boolean = boolean.type_as(val1)
return (1 - boolean) * val1 + boolean * val2
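The _post_process changes above keep the per-token latency bookkeeping: time the gap between successive items yielded by the token iterator, skipping the first warm-up steps. A standalone sketch of that pattern with a toy iterator and no device synchronization:

```python
# Standalone sketch of the single-token latency measurement pattern above
# (a toy generator stands in for the batch token iterator; no device sync here).
import time

def toy_batch_token_iterator(steps=5):
    for _ in range(steps):
        time.sleep(0.01)      # stands in for one forward/sampling step
        yield "tokens"

single_token_latency = []
t0 = time.time()
for count, _tokens in enumerate(toy_batch_token_iterator()):
    if count > 1:             # skip warm-up iterations, as the code above does
        single_token_latency.append(time.time() - t0)
    t0 = time.time()

print(len(single_token_latency), "timed steps")
```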
@@ -165,9 +175,10 @@ def forward_step(model, tokens, **kwargs):
t0 = time.time()
args = get_args()
orig_seq_length = args.seq_length
args.seq_length = tokens.shape[1]
args.micro_batch_size = tokens.shape[0]
config = get_model_config(model)
tensor_shapes = [args.seq_length, args.micro_batch_size, args.hidden_size]

input_tensor = recv_forward(tensor_shapes, config)

_unwrap_and_set_input_tensor(args, input_tensor, model)
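forward_step now derives seq_length and micro_batch_size from the incoming tokens and builds the pipeline receive buffer shape as [seq_length, micro_batch_size, hidden_size]. A toy sketch of that computation (hidden_size is an illustrative value, not read from args):

```python
# Standalone sketch of the tensor_shapes computation above
# (toy tokens tensor; hidden_size is an illustrative value).
import torch

tokens = torch.zeros(2, 128, dtype=torch.long)   # (micro_batch_size, seq_length)
hidden_size = 4096

seq_length = tokens.shape[1]
micro_batch_size = tokens.shape[0]
tensor_shapes = [seq_length, micro_batch_size, hidden_size]   # [s, b, h] buffer layout

print(tensor_shapes)   # [128, 2, 4096]
```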
@@ -241,10 +252,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths, type_ids=None,
counter = 0
layer_past = None
batch_size = tokens.size(0)
max_length = tokens.size(1)
max_length = args.max_length_ori
context_length = context_lengths.min().item()
is_done = torch.zeros([batch_size]).byte().to(get_accelerator().device_name())
lengths = torch.ones([batch_size]).long().to(get_accelerator().device_name()) * max_length

while context_length < max_length:
if args.text_generation_config['recompute']:
@@ -277,7 +287,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths, type_ids=None,

done = _is_done(is_done, prev, started, tokenizer)

yield tokens, lengths, output_log_probs
yield tokens, max_length, output_log_probs

context_length += 1
counter += 1
@@ -33,5 +33,5 @@ python -m torch.distributed.launch $DISTRIBUTED_ARGS ./tasks/inference/inference
--micro-batch-size 1 \
--seq-length 1024 \
--max-new-tokens 256 \
--seed 42
--seed 42 \
--position-embedding-type alibi \
@@ -62,7 +62,7 @@ def task_do_sample(args, model, tokenizer=None, system_template="", dialog_templ

t = time.time()
output = model.generate(
instruction,
[instruction, instruction],
do_sample=True,
top_k=args.top_k,
top_p=args.top_p,