Mirror of https://gitee.com/ascend/ModelLink.git, synced 2024-12-05 05:17:40 +08:00
The llama2 model now uses the llama model code throughout; remove the llama2-specific code that is no longer needed.
This commit is contained in:
parent 63ec3c7870
commit 6fe54d95d5
@@ -1,924 +0,0 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math

import torch
import torch.nn.functional as F
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec

from ascendspeed import get_args
from ascendspeed.core import tensor_parallel, parallel_state, utils
from ascendspeed.model.module import MegatronModule, MegatronModuleForCausalLM, float16_to_fp32, fp32_to_float16
from ascendspeed.core.enums import AttnMaskType, AttnType
from ascendspeed.model.utils import init_method_normal, scaled_init_method_normal, attention_mask_func
from ascendspeed.core.tensor_parallel.mappings import scatter_to_sequence_parallel_region
from ascendspeed.model.fused_softmax import NPUFusedScaleMaskSoftmax
from ascendspeed.model.language_model import Pooler
from ascendspeed.error_utils import check_equal, check_divisible


class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (torch.tensor(base).double() ** (torch.arange(0, dim, 2).float().to(device) / dim).double())
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
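

# Editor's illustration (not part of the original file): a minimal sketch of how RotaryEmbedding
# and apply_rotary_pos_emb above fit together, assuming the [b, np, sq, hn] layout that the
# attention code further below uses. Tensor sizes here are made up for the example.
def _rope_shape_example():
    head_dim, seq_len = 8, 4
    rope = RotaryEmbedding(head_dim)
    q = torch.randn(1, 2, seq_len, head_dim)   # [b, np, sq, hn]
    k = torch.randn(1, 2, seq_len, head_dim)
    cos, sin = rope(k, seq_len=seq_len)        # each [1, 1, sq, hn]; broadcasts over b and np
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    return q_rot.shape == q.shape and k_rot.shape == k.shape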


class RMSNorm(torch.nn.Module):  # for cpu
    def __init__(self, hidden_size, eps=1e-6, sequence_parallel=False):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon).half()
        hidden_states = self.weight * hidden_states

        return hidden_states
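

# Editor's illustration (not part of the original file): the forward above computes
# y = weight * x / sqrt(mean(x ** 2) + eps), i.e. RMS normalization without the mean
# subtraction and bias of LayerNorm. A plain float reference for comparison:
def _rmsnorm_reference(x, weight, eps=1e-6):
    rms = torch.sqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return weight * (x / rms)
# e.g. for norm = RMSNorm(16) and x = torch.randn(2, 4, 16), norm(x) should match
# _rmsnorm_reference(x, norm.weight) up to the half-precision rounding introduced by the
# .half() cast in the forward above.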


class Llama2LMHead(MegatronModule):
    """Causal LM head for Llama

    Arguments:
        vocab_size: size of vocabulary.
        hidden_size: hidden size
        gather_output: whether the output logits are gathered or not.
        init_method: init method for weight initialization
    """

    def __init__(self,
                 config,
                 hidden_size,
                 vocab_size,
                 init_method,
                 parallel_output=True):
        super(Llama2LMHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.init_method = init_method
        self.parallel_output = parallel_output
        self.sequence_parallel = args.sequence_parallel
        self.lm_head = tensor_parallel.ColumnParallelLinear(input_size=self.hidden_size,
                                                            output_size=vocab_size,
                                                            config=config,
                                                            bias=False,
                                                            gather_output=not self.parallel_output,
                                                            skip_bias_add=True,
                                                            init_method=self.init_method,
                                                            )

    def forward(self, inputs):
        inputs = inputs.transpose(0, 1).contiguous() if self.sequence_parallel else inputs
        logits, _ = self.lm_head(inputs)
        logits = logits.transpose(0, 1).contiguous() if self.sequence_parallel else logits
        return logits


class Llama2LMHeadPipe(Llama2LMHead):

    def forward(self, inputs, **kwargs):
        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
        if isinstance(inputs, tuple):
            hidden_states = inputs[0]
        else:
            hidden_states = inputs

        if not hasattr(self, '_args'):
            self._args = get_args()

        if hasattr(self._args, 'attn_mask'):
            attention_mask = None
        else:
            attention_mask = inputs[1]

        logits = super().forward(hidden_states)

        # If cmd args has attn_mask, we don't forward it as an activation.
        if hasattr(self._args, 'attn_mask'):
            return logits
        else:
            return logits, attention_mask


class Llama2Embedding(MegatronModule):
    """Language model embeddings.

    Arguments:
        hidden_size: hidden size
        vocab_size: vocabulary size
        init_method: weight initialization method
    """

    def __init__(self,
                 config,
                 hidden_size,
                 vocab_size,
                 init_method):
        super(Llama2Embedding, self).__init__()
        args = get_args()

        self.hidden_size = hidden_size
        self.init_method = init_method

        # Word embeddings (parallel).
        self.word_embeddings = tensor_parallel.VocabParallelEmbedding(vocab_size, self.hidden_size,
                                                                      init_method=self.init_method, config=config)
        self.sequence_parallel = args.sequence_parallel

    def forward(self, input_ids):
        # Embeddings.
        embeddings = self.word_embeddings(input_ids)
        if self.sequence_parallel:
            embeddings = embeddings.transpose(0, 1).contiguous()
            embeddings = scatter_to_sequence_parallel_region(embeddings)
            embeddings = embeddings.transpose(0, 1).contiguous()

        return embeddings


class Llama2EmbeddingPipe(Llama2Embedding):

    def forward(self, inputs, **kwargs):
        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
        if isinstance(inputs, tuple):
            input_ids = inputs[0]
        else:
            input_ids = inputs

        if not hasattr(self, '_args'):
            self._args = get_args()

        if hasattr(self._args, 'attn_mask'):
            attention_mask = None
        else:
            attention_mask = inputs[-1]

        embeddings = super().forward(input_ids)
        # If cmd args has attn_mask, we don't forward it as an activation.
        if not hasattr(self._args, 'attn_mask'):
            setattr(self._args, 'attn_mask', attention_mask)

        return embeddings


class Llama2ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to the intermediate
    hidden dimension, perform a nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, init_method, output_layer_init_method, moe=False,
                 enable_expert_tensor_parallelism=False):
        super(Llama2ParallelMLP, self).__init__()
        args = get_args()
        self.init_method = init_method
        self.output_layer_init_method = output_layer_init_method

        # Project to intermediate.
        self.gate_proj = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            config=config,
            bias=False,
            gather_output=False,
            init_method=self.init_method,
            skip_bias_add=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism,
        )

        self.up_proj = tensor_parallel.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            config=config,
            bias=False,
            gather_output=False,
            init_method=self.init_method,
            skip_bias_add=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism,
        )

        self.activation_func = F.silu

        # Project back to h.
        self.down_proj = tensor_parallel.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            config=config,
            bias=False,
            input_is_parallel=True,
            init_method=self.output_layer_init_method,
            skip_bias_add=True,
            moe=moe,
            enable_expert_tensor_parallelism=enable_expert_tensor_parallelism,
        )

    def forward(self, hidden_states):
        intermediate_parallel = self.activation_func(self.gate_proj(hidden_states)[0]) * self.up_proj(hidden_states)[0]

        output, _ = self.down_proj(intermediate_parallel)
        return output
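

# Editor's illustration (not part of the original file): the forward above is the gated-SiLU
# ("SwiGLU") MLP used by Llama, i.e. down_proj(silu(gate_proj(x)) * up_proj(x)). The same
# computation with plain nn.Linear layers, ignoring tensor parallelism and skip_bias_add:
class _PlainSwiGLU(torch.nn.Module):
    def __init__(self, hidden_size, ffn_hidden_size):
        super().__init__()
        self.gate_proj = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.up_proj = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=False)
        self.down_proj = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)

    def forward(self, x):
        # SiLU-gated branch multiplied elementwise with the up branch, then projected back to h.
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))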


class Llama2ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [b, s, h]
    and returns output of the same size.
    """

    def __init__(self, config, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.causal):
        super(Llama2ParallelAttention, self).__init__()

        check_equal(attention_type, AttnType.self_attn)
        check_equal(attn_mask_type, AttnMaskType.causal)

        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16
        self.sequence_parallel = args.sequence_parallel
        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.init_method = init_method
        self.output_layer_init_method = output_layer_init_method

        self.num_attention_heads = args.num_attention_heads
        projection_size = args.kv_channels * args.num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = utils.divide(projection_size,
                                                      world_size)
        self.hidden_size_per_attention_head = utils.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = utils.divide(
            args.num_attention_heads, world_size)

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                config=config,
                bias=False,
                gather_output=False,
                init_method=self.init_method,
            )

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)

        self.scale_mask_softmax = NPUFusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            (1 / self.norm_factor))

        # Rotary Position Embedding
        self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head)

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            args.hidden_size,
            config=config,
            bias=False,
            input_is_parallel=True,
            init_method=self.output_layer_init_method,
            skip_bias_add=True,
        )

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint

    def forward(self, hidden_states, attention_mask, layer_past=None,
                get_key_value=False):
        # hidden_states: [sq, b, h]

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                               (self.num_attention_heads_per_partition,
                                3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, 3 * h] --> 3 [sq, b, h]
            (query_layer,
             key_layer,
             value_layer) = tensor_parallel.utils.split_tensor_along_last_dim(mixed_x_layer, 3)

        # ==================================
        # Rotary Position Embedding
        # ==================================
        # [sq, b, np, hn] --> [b, np, sq, hn] TODO optimize the permute of dimension back and forth
        query_layer = query_layer.permute(1, 2, 0, 3).contiguous()
        key_layer = key_layer.permute(1, 2, 0, 3).contiguous()
        value_layer = value_layer.permute(1, 2, 0, 3).contiguous()

        cos, sin = self.rotary_emb(value_layer, seq_len=new_tensor_shape[0])
        query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, offset=0)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer),
                                   key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer),
                                     value_layer), dim=0)
        if get_key_value:
            present = (key_layer, value_layer)

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        attention_scores = torch.matmul(query_layer, key_layer.transpose(3, 2))

        # ==================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ==================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ...,
                        attention_scores.size(3) - 1,
                        :attention_scores.size(3)].unsqueeze(2)
                else:
                    attention_mask = attention_mask[
                        ...,
                        :attention_scores.size(3),
                        :attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        context_layer = torch.matmul(attention_probs, value_layer)

        bs, nh, sq, hd = context_layer.shape
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
        context_layer = context_layer.view(sq, bs, nh * hd)

        output, _ = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output
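

# Editor's note (not part of the original file): the per-partition sizes computed in __init__ above,
# worked through for hypothetical settings hidden_size=4096, num_attention_heads=32, kv_channels=128
# and a tensor-model-parallel world size of 8:
#   projection_size                   = 128 * 32  = 4096
#   hidden_size_per_partition         = 4096 / 8  = 512
#   hidden_size_per_attention_head    = 4096 / 32 = 128
#   num_attention_heads_per_partition = 32 / 8    = 4
# so each tensor-parallel rank holds 4 heads of width 128, and query_key_value produces
# 3 * 4096 / 8 = 1536 output channels per rank (the factor 3 comes from the fused Q, K, V projection).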


class Llama2ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [b, s, h] and returns an
    output of the same size.
    """

    def __init__(self, config, init_method, output_layer_init_method,
                 layer_number,
                 self_attn_mask_type=AttnMaskType.causal):
        args = get_args()

        super(Llama2ParallelTransformerLayer, self).__init__()
        self.config = config
        self.layer_number = layer_number
        check_equal(self_attn_mask_type, AttnMaskType.causal)

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.init_method = init_method
        self.output_layer_init_method = output_layer_init_method

        # Layernorm on the input data.
        self.input_layernorm = RMSNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.attention = Llama2ParallelAttention(
            self.config,
            self.init_method,
            self.output_layer_init_method,
            layer_number,
            attn_mask_type=self_attn_mask_type)

        # Layernorm on the attention output
        self.post_attention_layernorm = RMSNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            sequence_parallel=args.sequence_parallel)

        # MLP
        self.rank = args.rank
        self.mlp = Llama2ParallelMLP(config, self.init_method, self.output_layer_init_method)

    def forward(self, hidden_states, attention_mask=None,
                layer_past=None, get_key_value=False):
        # hidden_states: [b, s, h]
        residual = hidden_states
        # Layer norm at the beginning of the transformer layer.
        hidden_states = self.input_layernorm(hidden_states)
        # Self attention.
        hidden_states = self.attention(hidden_states,
                                       attention_mask,
                                       layer_past=layer_past,
                                       get_key_value=get_key_value)

        if get_key_value:
            hidden_states, presents = hidden_states

        # Residual connection.
        hidden_states = hidden_states + residual
        residual = hidden_states

        # Layer norm post the self attention.
        hidden_states = self.post_attention_layernorm(hidden_states)

        # MLP.
        hidden_states = self.mlp(hidden_states)
        hidden_states = hidden_states + residual

        if get_key_value:
            hidden_states = [hidden_states, presents]
        return hidden_states


class Llama2ParallelTransformerLayerPipe(Llama2ParallelTransformerLayer):
    """Extends ParallelTransformerLayer to forward attention_mask through the pipeline.

    Forward has two usages that affect attention mask communication:

    1) forward((input, attn_mask), **kwargs) -> (output, mask)
       When the attention mask is provided as the second positional
       argument, typical pipeline behavior is used and both the output
       *and* mask are returned in a tuple. This tuple is then forwarded
       to the next stage in the pipeline.

       This version is useful if masks are dynamic.

    2) forward(input, **kwargs) -> output
       When the mask is static over all samples, it is advantageous to
       cache the mask and avoid communicating it.

       If no mask is provided, the module will query `self._args.attn_mask`
       for the mask and only return `super().forward(...)`
    """

    def forward(self, inputs, **kwargs):
        assert torch.is_tensor(inputs) or isinstance(inputs, tuple)
        if torch.is_tensor(inputs) or len(inputs) == 1:
            # No attention mask forwarded, search for args.attn_mask.
            if not hasattr(self, '_args'):
                self._args = get_args()
            hidden_states, attention_mask = inputs, self._args.attn_mask
            return super().forward(hidden_states, attention_mask, **kwargs)
        elif len(inputs) == 2:
            # Attention mask is an activation.
            hidden_states, attention_mask = inputs[0], inputs[1]
            return super().forward(*inputs, **kwargs), attention_mask
        else:
            raise RuntimeError('Received more inputs than understood.')
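

# Editor's note (not part of the original file): the two call patterns described in the docstring
# above, sketched with hypothetical tensors (layer being an instance of the class above):
#   x = torch.randn(seq, batch, hidden)
#   out, mask = layer((x, attn_mask))   # mask travels through the pipeline as an activation
#   out = layer(x)                      # mask is taken from get_args().attn_mask instead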


class Llama2ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, config, init_method, output_layer_init_method,
                 self_attn_mask_type=AttnMaskType.causal,
                 pre_process=True, post_process=True):

        super(Llama2ParallelTransformer, self).__init__()
        args = get_args()
        check_equal(self_attn_mask_type, AttnMaskType.causal)
        self.config = config
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.ds_inference = args.ds_inference
        self.init_method = init_method
        self.output_layer_init_method = output_layer_init_method

        # Store activation checkpointing flag.
        self.checkpoint_activations = args.checkpoint_activations
        self.checkpoint_num_layers = args.checkpoint_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        # Number of layers.
        error_info = 'num_layers must be divisible by pipeline_model_parallel_size'
        check_divisible(args.num_layers, parallel_state.get_pipeline_model_parallel_world_size(), error_info)
        self.num_layers = args.num_layers // parallel_state.get_pipeline_model_parallel_world_size()

        # Transformer layers.
        def build_layer(layer_number):
            return Llama2ParallelTransformerLayer(
                config,
                self.init_method,
                self.output_layer_init_method,
                layer_number)

        if args.virtual_pipeline_model_parallel_size is not None:
            error_info = 'num_layers_per_stage must be divisible by ' \
                         'virtual_pipeline_model_parallel_size'
            check_divisible(args.num_layers, args.virtual_pipeline_model_parallel_size, error_info)
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0] [2] [4] [6]
            # Stage 1: [1] [3] [5] [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1] [4, 5]
            # Stage 1: [2, 3] [6, 7]
            offset = parallel_state.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (parallel_state.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            offset = parallel_state.get_pipeline_model_parallel_rank() * self.num_layers

        self.layers = []
        # Build the layers.
        for i in range(self.num_layers):
            layer_num = i + 1 + offset
            self.layers.append(build_layer(layer_num))

        self.layers = torch.nn.ModuleList(self.layers)

        if self.post_process:
            # Final layer norm before output.
            self.final_layernorm = RMSNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                sequence_parallel=args.sequence_parallel)

        if deepspeed.checkpointing.is_configured():
            global get_cuda_rng_tracker, checkpoint
            get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
            checkpoint = deepspeed.checkpointing.checkpoint

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask):
        """Forward method with activation checkpointing."""

        def custom(start, end):
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask=attention_mask)
                return x_

            return custom_forward

        # Make sure memory is freed.
        tensor_parallel.reset_checkpointed_activations_memory_buffer()
        l = 0
        while l < self.num_layers:
            hidden_states = tensor_parallel.checkpoint(
                custom(l, l + self.checkpoint_num_layers),
                self.distribute_saved_activations,
                hidden_states, attention_mask)
            l += self.checkpoint_num_layers

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func."""
        if isinstance(input_tensor, (list, tuple)):
            self.input_tensor = input_tensor[0]
        else:
            self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask, layer_past=None, get_key_value=False):

        # Checks.
        if layer_past is not None:
            assert get_key_value, \
                'for not None values in layer_past, ' \
                'expected get_key_value to be set'
        if get_key_value:
            assert not self.checkpoint_activations, \
                'get_key_value does not work with ' \
                'activation checkpointing'

        # Reza's note: DeepSpeed inference does not support transposes.
        if not self.ds_inference:
            if self.pre_process:
                # Data format change to avoid explicit transposes: [b s h] --> [s b h].
                # If the input flag for fp32 residual connection is set, convert to float.
                if self.fp32_residual_connection:
                    hidden_states = hidden_states.transpose(0, 1).contiguous().float()
                # Otherwise, leave it as is.
                else:
                    hidden_states = hidden_states.transpose(0, 1).contiguous()
            else:
                # See set_input_tensor().
                hidden_states = self.input_tensor

        if self.checkpoint_activations:
            hidden_states = self._checkpointed_forward(hidden_states, attention_mask)
        else:
            if get_key_value:
                presents = []
            for index in range(self.num_layers):
                layer = self._get_layer(index)
                past = None
                if layer_past is not None:
                    past = layer_past[index]
                hidden_states = layer(hidden_states,
                                      attention_mask=attention_mask,
                                      layer_past=past,
                                      get_key_value=get_key_value)
                if get_key_value:
                    hidden_states, present = hidden_states
                    presents.append(present)

        # Final layer norm.
        if self.post_process:
            if not self.ds_inference:
                # Reverting data format change [s b h] --> [b s h].
                hidden_states = hidden_states.transpose(0, 1).contiguous()
            output = self.final_layernorm(hidden_states)
        else:
            output = hidden_states
        if get_key_value:
            output = [output, presents]

        return output
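

# Editor's note (not part of the original file): the offset formula above, evaluated for the
# 8-layer, 2-stage, 2-model-chunk example from the comments (args.num_layers=8,
# pipeline_model_parallel_world_size=2, virtual_pipeline_model_parallel_size=2, so each chunk
# builds self.num_layers = 8 // 2 // 2 = 2 layers):
#   offset = virtual_rank * (8 // 2) + pipeline_rank * 2
#   stage 0, chunk 0 -> offset 0 -> layers [1, 2]    stage 0, chunk 1 -> offset 4 -> layers [5, 6]
#   stage 1, chunk 0 -> offset 2 -> layers [3, 4]    stage 1, chunk 1 -> offset 6 -> layers [7, 8]
# (layer numbers are 1-based here because build_layer is called with i + 1 + offset).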


def CrossEntropy(output, labels):
    labels, loss_mask = labels[0], labels[1]

    args = get_args()
    losses = tensor_parallel.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss


class Llama2ModelPipe(PipelineModule, MegatronModule, MegatronModuleForCausalLM):
    """Llama2 language model (pipeline-parallel)."""

    def __init__(self, config, parallel_output=True):
        args = get_args()

        self.init_method = init_method_normal(args.init_method_std)
        self.output_layer_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
        self.parallel_output = parallel_output

        self.specs = []

        def _to_float16(inputs):
            if args.fp16:
                return fp32_to_float16(inputs, lambda v: v.half())
            elif args.bf16:
                return fp32_to_float16(inputs, lambda v: v.bfloat16())
            else:
                return inputs

        self.specs.append(_to_float16)

        # Embedding layer
        self.specs.append(LayerSpec(Llama2EmbeddingPipe, config=config, hidden_size=args.hidden_size,
                                    vocab_size=args.padded_vocab_size, init_method=self.init_method))

        if args.fp32_residual_connection:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
        else:
            self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        for layer_idx in range(args.num_layers):
            self.specs.append(
                LayerSpec(Llama2ParallelTransformerLayerPipe,
                          config=config,
                          init_method=self.init_method,
                          output_layer_init_method=self.output_layer_init_method,
                          layer_number=layer_idx))

        # Undo data format change
        self.specs.append(lambda x: x.transpose(0, 1).contiguous())

        # Final layernorm after transformer layers
        self.specs.append(
            LayerSpec(RMSNorm, args.hidden_size, eps=args.layernorm_epsilon,
                      sequence_parallel=args.sequence_parallel))

        self.specs.append(
            LayerSpec(Llama2LMHeadPipe, config=config, hidden_size=args.hidden_size,
                      vocab_size=args.padded_vocab_size, init_method=self.init_method,
                      parallel_output=self.parallel_output)
        )

        # Convert to fp32 if needed
        if args.fp16 or args.bf16:
            self.specs.append(float16_to_fp32)

        if args.checkpoint_activations:
            interval = args.checkpoint_num_layers
        else:
            interval = 0

        from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
        topo = PipeModelDataParallelTopology(num_pp=parallel_state.get_pipeline_model_parallel_world_size(),
                                             num_mp=parallel_state.get_tensor_model_parallel_world_size(),
                                             num_dp=parallel_state.get_data_parallel_world_size())

        super().__init__(layers=self.specs,
                         loss_fn=CrossEntropy,
                         topology=topo,
                         activation_checkpoint_interval=interval,
                         partition_method='type:transformer')


class Llama2Model(MegatronModule, MegatronModuleForCausalLM):
    """Llama2 language model."""

    def __init__(self, config, pre_process, post_process, parallel_output=True, add_pooler=False, **kwargs):
        super(Llama2Model, self).__init__(config=config, share_word_embeddings=False)
        args = get_args()
        self.config = config
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.hidden_size = args.hidden_size
        self.pre_process = pre_process
        self.post_process = post_process
        self.parallel_output = parallel_output
        self.add_pooler = add_pooler
        self.init_method = init_method_normal(args.init_method_std)
        self.output_layer_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
        self.self_attn_mask_type = AttnMaskType.causal
        self.padded_vocab_size = args.padded_vocab_size

        if self.pre_process:
            self.embedding = Llama2Embedding(config=config,
                                             hidden_size=args.hidden_size,
                                             init_method=self.init_method,
                                             vocab_size=self.padded_vocab_size)

        # Transformer.
        self.language_model = Llama2ParallelTransformer(
            self.config,
            self.init_method,
            self.output_layer_init_method,
            self_attn_mask_type=self.self_attn_mask_type,
            pre_process=self.pre_process,
            post_process=self.post_process,
        )

        if self.post_process:
            # Pooler.
            if self.add_pooler:
                self.pooler = Pooler(self.hidden_size, self.init_method)

            self.lm_head = Llama2LMHead(config=self.config,
                                        hidden_size=args.hidden_size,
                                        vocab_size=self.padded_vocab_size,
                                        init_method=self.init_method,
                                        parallel_output=self.parallel_output)

    def set_input_tensor(self, input_tensor):
        """See ascendspeed.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, attention_mask, labels=None, layer_past=None, get_key_value=False,
                pooling_sequence_index=0):
        args = get_args()

        if self.pre_process:
            hidden_states = self.embedding(input_ids)
        else:
            hidden_states = input_ids

        # Decoder.
        hidden_states = self.language_model(hidden_states, attention_mask, layer_past=layer_past,
                                            get_key_value=get_key_value)

        if self.post_process:
            if get_key_value:
                hidden_states, presents = hidden_states

            if self.add_pooler:
                hidden_states = self.pooler(hidden_states, pooling_sequence_index)

            hidden_states = self.lm_head(hidden_states)

            if labels is None:
                if get_key_value:
                    return [hidden_states, presents]
                return hidden_states

            else:
                if self.fp16_lm_cross_entropy:
                    check_equal(hidden_states.dtype, torch.half)
                    loss = tensor_parallel.vocab_parallel_cross_entropy(hidden_states, labels)
                else:
                    loss = tensor_parallel.vocab_parallel_cross_entropy(hidden_states.float(), labels)
                return loss

        return hidden_states
@@ -1,56 +0,0 @@
# This is an example: training llama using PTD.

# The number of parameters is not aligned
export LD_LIBRARY_PATH=/usr/local/lib:/usr/local/lib:/root/miniconda3/lib:$LD_LIBRARY_PATH
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1

# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6001
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=8
WORLD_SIZE=$(($NPUS_PER_NODE*$NNODES))

DATA_PATH=./dataset/llama_text_document
CHECKPOINT_PATH=./ckpt
rm -rf ./ckpt/*
DISTRIBUTED_ARGS="--nproc_per_node $NPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

# Main script
python -m torch.distributed.launch $DISTRIBUTED_ARGS \
       pretrain_llama.py \
       --DDP-impl local \
       --tensor-model-parallel-size 4 \
       --pipeline-model-parallel-size 2 \
       --num-layers 30 \
       --hidden-size 6656 \
       --ffn-hidden-size 17920 \
       --num-attention-heads 52 \
       --micro-batch-size 1 \
       --global-batch-size 32 \
       --seq-length 2048 \
       --max-position-embeddings 4096 \
       --train-iters 500000 \
       --lr-decay-iters 320000 \
       --save $CHECKPOINT_PATH \
       --load $CHECKPOINT_PATH \
       --data-path $DATA_PATH \
       --tokenizer-name-or-path ./dataset/llama/ \
       --tokenizer-not-use-fast \
       --data-impl mmap \
       --split 949,50,1 \
       --distributed-backend nccl \
       --lr 0.00015 \
       --lr-decay-style cosine \
       --min-lr 1.0e-5 \
       --weight-decay 1e-2 \
       --clip-grad 1.0 \
       --lr-warmup-fraction .01 \
       --log-interval 1 \
       --save-interval 10000 \
       --eval-interval 1000 \
       --eval-iters 10 \
       --initial-loss-scale 4096.0 \
       --fp16 | tee logs/train.log
@@ -1,239 +0,0 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain Llama2."""

import math
from functools import partial

import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu

import deepspeed
import deepspeed_npu
from deepspeed.runtime.utils import see_memory_usage
from deepspeed.accelerator.real_accelerator import get_accelerator

from ascendspeed import get_args
from ascendspeed import print_rank_0
from ascendspeed import get_timers
from ascendspeed import get_tokenizer
from ascendspeed.core import parallel_state, tensor_parallel
from ascendspeed.data.gpt_dataset import build_train_valid_test_datasets
from ascendspeed.model.llama2_model import Llama2Model, Llama2ModelPipe
from ascendspeed.training import pretrain
from ascendspeed.utils import get_ltor_masks_and_position_ids
from ascendspeed.utils import average_losses_across_data_parallel_group
from ascendspeed.arguments import core_transformer_config_from_args
from ascendspeed.core.enums import ModelType


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('Building llama model ...')
    see_memory_usage("Before Building Model ...", force=True)

    args = get_args()
    config = core_transformer_config_from_args(get_args())
    with deepspeed.zero.Init(data_parallel_group=parallel_state.get_data_parallel_group(),
                             remote_device=None if args.remote_device == 'none' else args.remote_device,
                             config_dict_or_path=args.deepspeed_config,
                             enabled=args.zero_stage == 3,
                             mpu=parallel_state):
        if args.deepspeed and not args.no_pipeline_parallel:
            model = Llama2ModelPipe(config=config, parallel_output=True)
            # This is a hack to give us a reference to get_batch_pipe from within training.py.
            # We need to call model.set_batch_fn after deepspeed.initialize.
            model._megatron_batch_fn = get_batch_pipe

            # Precompute the attention mask and store it in args. This avoids having to
            # pipeline it as an activation during training. The mask is constant, and thus
            # we can reuse it.
            attention_mask = torch.tril(torch.ones(
                (1, args.seq_length, args.seq_length),
                device=get_accelerator().current_device_name())).view(
                1, 1, args.seq_length, args.seq_length)

            # Convert attention mask to binary:
            attention_mask = (attention_mask < 0.5)
            if args.fp16:
                attention_mask = attention_mask.half()
            elif args.bf16:
                attention_mask = attention_mask.bfloat16()

            # Attention mask must be bool.
            args.attn_mask = attention_mask.to(torch.bool)
        else:
            model = Llama2Model(
                config=config,
                parallel_output=True,
                add_pooler=False,
                pre_process=pre_process,
                post_process=post_process
            )
    see_memory_usage("After Building Model", force=True)
    return model
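

# Editor's note (not part of the original file): what the precomputed args.attn_mask looks like for
# a hypothetical seq_length of 4. torch.tril gives the causal (lower-triangular) pattern, and
# "< 0.5" flips it so that True marks the upper-triangle positions that must be masked out:
#   torch.tril(torch.ones(1, 4, 4))        (attention_mask < 0.5)
#   [[1, 0, 0, 0],                         [[False, True,  True,  True ],
#    [1, 1, 0, 0],                          [False, False, True,  True ],
#    [1, 1, 1, 0],                          [False, False, False, True ],
#    [1, 1, 1, 1]]                          [False, False, False, False]]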


def get_batch(data_iterator):
    """Generate a batch."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    data_type = torch.int64

    # Broadcast data.
    if hasattr(data_iterator, '__next__'):
        data = next(data_iterator)
    else:
        if isinstance(data_iterator, list):
            return data_iterator.pop(0)
        else:
            data = None
    data_b = tensor_parallel.broadcast_data(keys, data, data_type)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    if args.foldx_mode is not None:
        if hasattr(data_iterator, 'dummy_iterators'):
            for iterator in data_iterator.dummy_iterators:
                iterator.append((tokens, labels, loss_mask, attention_mask,))

    return tokens, labels, loss_mask, attention_mask
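

# Editor's note (not part of the original file): the one-token shift above on a toy sample.
# For a packed sequence tokens_ = [t0, t1, t2, t3, t4]:
#   tokens = tokens_[:, :-1] -> [t0, t1, t2, t3]   (model input)
#   labels = tokens_[:, 1:]  -> [t1, t2, t3, t4]   (next-token targets, aligned position-wise)
# so position i of the output is trained to predict token i + 1 of the original sequence.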


def data_post_process(data, data_sampler_state_dict):
    args = get_args()
    if args.data_efficiency_curriculum_learning:
        if 'seqlen_truncate' in data_sampler_state_dict['current_difficulties']:
            args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_truncate'
            current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_truncate']
            if current_seqlen < args.seq_length:
                data['text'] = data['text'][:, :(current_seqlen + 1)].contiguous()
        elif 'seqlen_reshape' in data_sampler_state_dict['current_difficulties']:
            args.data_efficiency_curriculum_learning_seqlen_type = 'seqlen_reshape'
            current_seqlen = data_sampler_state_dict['current_difficulties']['seqlen_reshape']
            if current_seqlen < args.seq_length:
                orig_num_token = torch.numel(data['text'])
                reshape_len = (data['text'].size()[1] // (current_seqlen + 1)) * (current_seqlen + 1)
                data['text'] = torch.cat((data['text'][:, :reshape_len].contiguous().view(-1, current_seqlen + 1),
                                          data['text'][:, -(current_seqlen + 1):]), 0).contiguous()
                num_row = math.ceil(orig_num_token / (current_seqlen + 1))
                num_row = min(num_row, data['text'].size()[0])
                if num_row > 1 and num_row % 2 != 0:
                    num_row -= 1
                data['text'] = data['text'][:num_row, :].contiguous()
        else:
            args.data_efficiency_curriculum_learning_seqlen_type = None
    return data


def get_batch_pipe(data):
    """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`."""
    args = get_args()
    tokenizer = get_tokenizer()

    # Items and their type.
    keys = ['text']
    data_type = torch.int64

    # Broadcast data.
    data_b = tensor_parallel.broadcast_data(keys, data, data_type)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, _ = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)
    return (tokens, attention_mask), (labels, loss_mask)


def loss_func(loss_mask, output_tensor):
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model):
    """Forward step."""
    args = get_args()

    timers = get_timers()
    # Get the batch.
    if args.foldx_mode is None:
        timers('batch-generator').start()
    tokens, labels, loss_mask, attention_mask = get_batch(data_iterator)
    if args.foldx_mode is None:
        timers('batch-generator').stop()

    output_tensor = model(tokens, attention_mask, labels=labels)
    # output_tensor stores the standard loss; loss_func calculates the total loss.
    return output_tensor, partial(loss_func, loss_mask)


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0('> building train, validation, and test datasets '
                 'for llama ...')
    train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
        data_prefix=args.data_path,
        data_impl=args.data_impl,
        splits_string=args.split,
        train_valid_test_num_samples=train_val_test_num_samples,
        seq_length=args.seq_length,
        seed=args.seed,
        skip_warmup=(not args.mmap_warmup))
    print_rank_0("> finished creating llama2 datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
    torch.npu.set_compile_mode(jit_compile=True)
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'PretrainedFromHF'},
             data_post_process=data_post_process)