Mirror of https://gitee.com/ascend/ModelLink.git, synced 2024-12-05 05:17:40 +08:00
!330 Support adaptive recompute selection: optimize the recompute policy selection algorithm and add UTs
Merge pull request !330 from Gou Jicheng/dev_from_master_only_recompute
Commit a5d76b70b1
@@ -13,15 +13,20 @@ class GraphSolver:
        self.total_recompute_cost = 0
        self.total_forward_cost = 0
        self.layers_module = None
        self.transformer_module = None
        self.recompute_policy = {}
        self.layers_combination = []
        self.layer_full_recompute_combination = None
        self.layer_without_recompute_combination = None
        self.layer_recompute_one_combination = None

    @staticmethod
-    def print_recompute_op(graph):
+    def get_recompute_op(graph):
        recompute_nodes = []
        for node in graph.nodes:
            if graph.nodes[node]['recompute']:
                recompute_nodes.append(graph.nodes[node]['name'])
        print_rank_0(f'recompute nodes = {recompute_nodes}')
        return recompute_nodes

    @staticmethod
    def dg_init(no_recompute_layer):
@@ -48,28 +53,39 @@ class GraphSolver:
            group=parallel_state.get_tensor_model_parallel_group())
        return recompute_policy_tensor.cpu().numpy().tolist()

-    @staticmethod
-    def apply_policy_to_model(recompute_policy_new, full_model):
-        if recompute_policy_new[1] == 0:
-            for idx, module in enumerate(full_model):
-                if idx < recompute_policy_new[0]:
-                    module['recompute'] = True
-                else:
-                    module['recompute'] = False
-                    idx_submodule = 0
-                    for layer in module['layers']:
-                        if recompute_policy_new[idx_submodule + 2] == 1:
-                            layer['recompute'] = True
-                        idx_submodule += 1
-        else:
-            for idx, module in enumerate(full_model):
-                if idx < recompute_policy_new[0]:
-                    module['recompute'] = False
-                    idx_submodule = 0
-                    for layer in module['layers']:
-                        if recompute_policy_new[idx_submodule + 2] == 1:
-                            layer['recompute'] = True
-                        idx_submodule += 1
+    def set_recompute_info_to_module(self, module, recompute_nodes, recompute):
+        if not recompute:
+            module["recompute"] = False
+            for layer in module["layers"]:
+                layer["recompute"] = False
+            return
+        if len(recompute_nodes) == 0:
+            module["recompute"] = True
+            return
+        sub_modules = module["layers"]
+        recompute_nodes_length = len(recompute_nodes)
+        for i in range(recompute_nodes_length):
+            if recompute_nodes[i] == self.layer_recompute_one_combination.broadcast_value:
+                sub_modules[i]["recompute"] = True
+                continue
+            sub_modules[i]["recompute"] = False
+
+    def apply_policy_to_model(self, recompute_policy_list):
+        full_layers = self.layers_module["layers"]
+        if len(recompute_policy_list) == 0:
+            return
+        idx = 0
+        for policy in recompute_policy_list:
+            n = policy[0]
+            recompute = False
+            recompute_nodes = []
+            if policy[1] != self.layer_without_recompute_combination.broadcast_value:
+                recompute = True
+                if policy[1] == self.layer_recompute_one_combination.broadcast_value:
+                    recompute_nodes = policy[2:]
+            for i in range(idx, idx + n):
+                self.set_recompute_info_to_module(full_layers[i], recompute_nodes, recompute)
+            idx += n

    # minimize the amount of memory used; in the limit everything is recomputed
    def calculate_cost_mem(self, g: nx.DiGraph, idx):
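To make the new interface concrete: apply_policy_to_model consumes a flattened list of rows, one row per group of identically treated layers. The following toy input is hypothetical (the numbers are invented; only the row layout and the 0/1/2 broadcast values from the LayerCombination definitions further down are taken from the diff):

# Toy policy for a model whose transformer layers each have 4 sub-modules.
# Row layout: [layer_count, combination_broadcast_value, per-sub-module flags...]
# where 0 = full recompute, 1 = selective recompute, 2 = no recompute.
recompute_policy_list = [
    [3, 0, 0, 0, 0, 0],  # 3 layers: recompute the whole layer (trailing flags unused)
    [5, 1, 1, 1, 2, 1],  # 5 layers: recompute only the sub-modules flagged 1
    [2, 2, 2, 2, 2, 2],  # 2 layers: keep all activations, no recompute
]
solver.apply_policy_to_model(recompute_policy_list)  # assumes a configured GraphSolver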
@@ -115,98 +131,220 @@ class GraphSolver:
                successor_cnt += 1
        return global_max_mem

    def cal_transformer_memory(self, model_layers):
        s = 0
        if 'layers' in model_layers:
            for layer in model_layers['layers']:
                if str.isdigit(layer['name']):
                    s += layer['memory']
                else:
                    s += self.cal_transformer_memory(layer)
        return s

    def cal_non_transformer_memory(self, model):
        # total memory used
        model_memory = model['layers'][0]['memory']
-        model_layers = model['layers'][0]
-        transformer_layer_memory = self.cal_transformer_memory(model_layers)
+        transformer_layer_memory = self.transformer_module['memory']
        non_size = model_memory - transformer_layer_memory
        print_rank_0(f"non size {model_memory} {non_size}")
        return non_size

-    def dfs_best(self, g, idx, config):
+    def layers_combination_init(self, g, idx, config):
+        if idx == 0:
+            self.layer_full_recompute_combination = LayerCombination({
+                "name": "full_recompute",
+                "num": config["nlayer"],
+                "memory": config["chp_input"],
+                "cost": config["chp_time"],
+                "broadcast_value": 0,
+                "policy_name": "n_full"
+            })
+            self.layers_combination.append(self.layer_full_recompute_combination)
+            self.layer_without_recompute_combination = LayerCombination({
+                "name": "without_recompute",
+                "num": config["nlayer"],
+                "memory": config["full_activation"],
+                "cost": 0,
+                "broadcast_value": 2,
+                "policy_name": "n_without"
+            })
+            self.layers_combination.append(self.layer_without_recompute_combination)
        if idx >= len(config['layers']):
-            self.search_recompute_policy(g, config)
+            recompute_nodes = self.get_recompute_op(g)
+            if len(recompute_nodes) == len(config['layers']) or len(recompute_nodes) == 0:
+                return
+            stash_mem_per_layer, recompute_cost = self.calculate_cost_mem(g, 0)
+            self.layer_recompute_one_combination = LayerCombination({
+                "name": ",".join(recompute_nodes),
+                "num": config["nlayer"],
+                "memory": stash_mem_per_layer,
+                "cost": recompute_cost,
+                "broadcast_value": 1,
+                "policy_name": "n_selective"
+            })
+            self.layers_combination.append(self.layer_recompute_one_combination)
            return
        g.nodes[idx]['recompute'] = False
-        self.dfs_best(g, idx + 1, config)
+        self.layers_combination_init(g, idx + 1, config)
        g.nodes[idx]['recompute'] = True
-        self.dfs_best(g, idx + 1, config)
+        self.layers_combination_init(g, idx + 1, config)

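The renamed layers_combination_init keeps the exhaustive DFS of the old dfs_best: it toggles the recompute flag of every node in the single-layer graph, so each leaf of the recursion evaluates one subset of sub-modules. A minimal equivalent sketch with itertools (the node names are illustrative, not taken from this diff):

import itertools

nodes = ["input_layernorm", "attention", "post_attention_layernorm", "mlp"]
# All 2**len(nodes) True/False assignments, like the recursive enumeration.
for flags in itertools.product([False, True], repeat=len(nodes)):
    recompute_nodes = [n for n, f in zip(nodes, flags) if f]
    # The empty and the full subset are skipped; they are already covered
    # by the dedicated without_recompute / full_recompute combinations.
    if len(recompute_nodes) in (0, len(nodes)):
        continue
    # ...evaluate stash memory and recompute cost for this subset...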
-    def search_recompute_policy(self, g, config):
-        stash_mem_per_layer, recompute_cost = self.calculate_cost_mem(g, 0)
-        peek = self.calculate_cost_peek(g, 0, 0, 0)
-        for i in range(config['nlayer']):
-            # if it is selective
-            stash_mem_total = (stash_mem_per_layer * i + config['full_activation'] * (config['nlayer'] - i)) * config['pp']
-            if config['static_memory_layer'] + stash_mem_total + peek < config['device_memory']:
-                recompute_total = recompute_cost * i  # * config['pp']
-                if recompute_total < self.total_recompute_cost:
-                    self.total_recompute_cost = recompute_total
-                    self.print_recompute_op(g)
-                    self.recompute_policy['config'] = 'n_selective'
-                    self.recompute_policy['policy'] = g.copy()
-                    self.recompute_policy['n'] = i
-                    try:
-                        print_rank_0(
-                            f"recompute policy {i}-selective: {config['static_memory_layer'] / 1024:.1f} GiB + "
-                            f"{stash_mem_total / 1024:.1f} GiB + {peek / 1024:.1f} GiB, "
-                            f"speed up compared with all recompute"
-                            f" {(self.total_forward_cost - recompute_total) / (4 * self.total_forward_cost) * 100:.2f}%")
-                    except ZeroDivisionError:
-                        print_rank_0("param error. total_forward_cost is 0.")
-            # if there is not enough memory
-            stash_mem_total = (stash_mem_per_layer * (config['nlayer'] - i) + config['chp_input'] * i) * config['pp']
-            if config['static_memory_layer'] + stash_mem_total + peek < config['device_memory']:
-                recompute_total = (
-                    recompute_cost * (config['nlayer'] - i) + config['chp_time'] * i)
-                if recompute_total < self.total_recompute_cost:
-                    self.total_recompute_cost = recompute_total
-                    self.print_recompute_op(g)
-                    self.recompute_policy['config'] = 'n_full'
-                    self.recompute_policy['policy'] = g.copy()
-                    self.recompute_policy['n'] = i
-                    try:
-                        print_rank_0(
-                            f"recompute policy {i}-full: {config['static_memory_layer'] / 1024:.1f} GiB + "
-                            f"{stash_mem_total / 1024:.1f} ({stash_mem_per_layer * (config['nlayer'] - i)} + "
-                            f"{config['chp_input'] * i}) GiB + {peek / 1024:.1f} GiB, "
-                            f"speed up compared with all recompute "
-                            f"{(self.total_forward_cost - recompute_total) / (4 * self.total_forward_cost) * 100:.2f}%")
-                    except ZeroDivisionError:
-                        print_rank_0("param error. total_forward_cost is 0.")
-
-    def analyse_policy_to_list(self, full_model, recompute_n, recompute_nodes):
-        if "config" in self.recompute_policy and self.recompute_policy["config"] != "n_full":
-            recompute_policy_list = [int(recompute_n), 1]
-        else:
-            recompute_policy_list = [int(recompute_n), 0]
-        for layer in full_model[0]['layers']:
-            if layer["name"] in recompute_nodes:
-                recompute_policy_list.append(1)
-            else:
-                recompute_policy_list.append(0)

+    def get_max_goods_value(self, idx, ans, config):
+        i, j, k = idx[0], idx[1], idx[2]
+        pre_step_ans = ans[i - 1][j - k]
+        if k == 0:
+            return pre_step_ans
+
+        goods_value = ans[i][j]
+        memory = pre_step_ans.memory + k * self.layers_combination[i].memory
+        cost = pre_step_ans.cost + k * self.layers_combination[i].cost
+        if pre_step_ans.cost == float('inf'):
+            cost = k * self.layers_combination[i].cost
+        try:
+            device_memory = max(config["device_memory"] - config["static_memory_layer"], 0) / config["pp"]
+        except ZeroDivisionError:
+            device_memory = max(config["device_memory"] - config["static_memory_layer"], 0)
+            print_rank_0("[ERROR] pipeline model parallel world size is 0. ")
+
+        if device_memory >= memory and cost <= goods_value.cost:
+            goods_value.memory = memory
+            goods_value.cost = cost
+            goods_value.layer_names.clear()
+            if len(pre_step_ans.layer_names) > 0:
+                goods_value.layer_names.extend(pre_step_ans.layer_names)
+            goods_value.layer_names.extend(self.layers_combination[i].name for _ in range(k))
+
+        return goods_value

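Read as a dynamic program, get_max_goods_value fills one cell of a grouped-knapsack table: ans[i][j] is the cheapest feasible way to cover j transformer layers using only the first i combinations. In our reading of the code (the notation is ours, not from the source), with $c_i$, $m_i$ the per-layer cost and memory of combination $i$, $num_i$ its copy limit, and $M$ the usable device memory per pipeline stage:

$$ans[i][j].cost = \min_{\substack{0 \le k \le \min(num_i,\, j) \\ ans[i-1][j-k].memory \,+\, k\,m_i \,\le\, M}} \big( ans[i-1][j-k].cost + k\,c_i \big)$$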
+    def print_recompute_policy(self, memory, cost, config):
+        fmt_str = "With selective recompute:\n"
+        for k, v in self.recompute_policy.items():
+            if k == self.layer_full_recompute_combination.name:
+                policy_name = self.layer_full_recompute_combination.policy_name
+            elif k == self.layer_without_recompute_combination.name:
+                policy_name = self.layer_without_recompute_combination.policy_name
+            else:
+                policy_name = self.layer_recompute_one_combination.policy_name
+            fmt_str += "recomputeNodes=[{}], ".format(k)
+            fmt_str += "{} {}; ".format(v, policy_name)
+        all_recompute_cost = len(self.layers_module["layers"]) * self.layer_full_recompute_combination.cost
+        try:
+            performance = (all_recompute_cost - cost) / (all_recompute_cost * 4)
+        except ZeroDivisionError:
+            performance = 0
+            print_rank_0("[ERROR] all recompute cost is 0. ")
+        fmt_str += "\ntotal mem cost: {:.1f} GiB + {:.1f} GiB, speed up compared with all recompute {:.2%}".format(
+            config["static_memory_layer"] / 1024, memory * config["pp"] / 1024, performance)
+        print_rank_0(fmt_str)

+    def get_all_layer_policy(self, combination_num, layer_num, ans, config):
+        layer_nodes = [self.layer_full_recompute_combination.name for _ in range(layer_num)]
+        memory = layer_num * self.layer_full_recompute_combination.memory
+        cost = layer_num * self.layer_full_recompute_combination.cost
+        for i in range(layer_num, 0, -1):
+            size = layer_num - len(ans[combination_num][i].layer_names)
+            if size != layer_num:
+                l_nodes = []
+                l_nodes.extend(ans[combination_num][i].layer_names)
+                # if the policies of all layers are not found, the remaining layers use the full recompute policy.
+                l_nodes.extend(self.layer_full_recompute_combination.name for _ in range(size))
+                l_memory = ans[combination_num][i].memory + size * self.layer_full_recompute_combination.memory
+                l_cost = ans[combination_num][i].cost + size * self.layer_full_recompute_combination.cost
+                if l_cost < cost:
+                    cost = l_cost
+                    memory = l_memory
+                    layer_nodes.clear()
+                    layer_nodes.extend(l_nodes)
+
+        for nodes in layer_nodes:
+            if nodes not in self.recompute_policy.keys():
+                self.recompute_policy.update({nodes: 1})
+                continue
+            self.recompute_policy.update({nodes: self.recompute_policy[nodes] + 1})
+
+        self.print_recompute_policy(memory, cost, config)

+    def knapsack_best(self, config):
+        combination_num = len(self.layers_combination)
+        layer_num = len(self.layers_module["layers"])
+        # make combination indices begin at 1
+        self.layers_combination.insert(0, None)
+        # init ans
+        ans = [[GoodsValue() for _ in range(layer_num + 1)] for _ in range(combination_num + 1)]
+        # find max goods value
+        for i in range(1, combination_num + 1):
+            for j in range(layer_num + 1):
+                k = 0
+                while k <= self.layers_combination[i].num and k <= j:
+                    ans[i][j] = self.get_max_goods_value([i, j, k], ans, config)
+                    k += 1
+        self.get_all_layer_policy(combination_num, layer_num, ans, config)

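A self-contained sketch of the same bounded-knapsack loop, with plain (cost, memory) tuples instead of GoodsValue and invented numbers, may help when reasoning about knapsack_best:

import math

combos = [(5.0, 2.0, 10), (0.0, 8.0, 10), (2.0, 4.0, 10)]  # (cost, memory, max_count) per combination
layer_num, budget = 10, 40.0                               # layers to cover, memory budget
ans = [[(math.inf, 0.0)] * (layer_num + 1) for _ in range(len(combos) + 1)]
ans[0][0] = (0.0, 0.0)
for i, (cost, mem, num) in enumerate(combos, start=1):
    for j in range(layer_num + 1):
        best = ans[i - 1][j]                               # k = 0: take nothing from combination i
        for k in range(1, min(num, j) + 1):
            prev_cost, prev_mem = ans[i - 1][j - k]
            cand = (prev_cost + k * cost, prev_mem + k * mem)
            if cand[1] <= budget and cand[0] < best[0]:
                best = cand
        ans[i][j] = best
print(ans[len(combos)][layer_num])                         # cheapest feasible cover of all layers

The real get_max_goods_value additionally restarts the cost sum when the predecessor cell is still infinite, so a row can be seeded before any fully feasible prefix exists.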
+    def analyse_policy_to_list(self):
+        recompute_policy_list = []
+        full_module_layers = self.layers_module["layers"][0]["layers"]
+        module_layers_num = len(full_module_layers)
+        for nodes_name, v in self.recompute_policy.items():
+            nodes_count = [v]
+            if nodes_name == self.layer_without_recompute_combination.name:
+                broadcast_value = self.layer_without_recompute_combination.broadcast_value
+                nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
+            elif nodes_name == self.layer_full_recompute_combination.name:
+                broadcast_value = self.layer_full_recompute_combination.broadcast_value
+                nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
+            else:
+                nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
+                recompute_nodes = nodes_name.split(",")
+                for layer in full_module_layers:
+                    if layer["name"] in recompute_nodes:
+                        nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
+                        continue
+                    nodes_count.append(self.layer_without_recompute_combination.broadcast_value)
+            recompute_policy_list.append(nodes_count)
+        return recompute_policy_list

+    def print_list_to_policy(self, recompute_policy_list):
+        layer_names = self.layers_module["layers"][0]["layers"]
+        module_layers_num = len(layer_names)
+        if len(recompute_policy_list) == 0:
+            return
+        fmt_str = ">> final selective strategy <<\n"
+        for policy in recompute_policy_list:
+            n = policy[0]
+            if policy[1] == self.layer_without_recompute_combination.broadcast_value:
+                policy_name = self.layer_without_recompute_combination.policy_name
+            elif policy[1] == self.layer_full_recompute_combination.broadcast_value:
+                policy_name = self.layer_full_recompute_combination.policy_name
+            else:
+                policy_name = self.layer_recompute_one_combination.policy_name
+            policy = policy[2:]
+            nodes = []
+            for i in range(module_layers_num):
+                if policy[i] == self.layer_recompute_one_combination.broadcast_value:
+                    nodes.append(layer_names[i]["name"])
+            fmt_str += "recomputeNodes=[{}], ".format(",".join(nodes))
+            fmt_str += "{} {}\n".format(n, policy_name)
+        print_rank_0(fmt_str)

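For orientation, the printed strategy for a hypothetical 10-layer model would come out roughly like this (counts and node names are invented; only the format string comes from the code):

>> final selective strategy <<
recomputeNodes=[], 2 n_without
recomputeNodes=[input_layernorm,attention,post_attention_layernorm], 7 n_selective
recomputeNodes=[], 1 n_full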
    def get_layers_module(self, model):
        if "name" in model and model["name"] == "layers":
            self.layers_module = model
-            return
+            return True
        if "layers" not in model:
-            return
+            return False
+        has_transformer_layer = False
        for sub_model in model["layers"]:
-            self.get_layers_module(sub_model)
+            has_transformer_layer = (has_transformer_layer or self.get_layers_module(sub_model))
+        if has_transformer_layer:
+            self.transformer_module = model
+        return False

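The reworked get_layers_module returns True exactly when a subtree contains the node named "layers", and the first ancestor that sees True from a child records itself as transformer_module. A toy walk-through on an invented miniature profile tree (this assumes GraphSolver can be constructed without arguments, which this hunk does not show):

model = {"name": "language_model",
         "layers": [{"name": "layers", "layers": [{"name": "0", "memory": 1}]}]}
solver = GraphSolver()                      # assumed no-arg construction
solver.get_layers_module(model)
assert solver.layers_module["name"] == "layers"               # the transformer block itself
assert solver.transformer_module["name"] == "language_model"  # its direct parent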
+class LayerCombination:
+    def __init__(self, config):
+        self.name = config["name"]
+        self.num = config["num"]
+        self.memory = config["memory"]
+        self.cost = config["cost"]
+        self.broadcast_value = config["broadcast_value"]
+        self.policy_name = config["policy_name"]
+
+
+class GoodsValue:
+    def __init__(self):
+        self.layer_names = []
+        self.memory = 0
+        self.cost = float('inf')


def solve_graph(model, pp, device_memory):
@@ -256,25 +394,10 @@ def generate_recompute_policy(solver, config):
    print_rank_0(
        f"With all recompute: total mem cost: {static_memory / 1024:.1f} GiB + {stash_mem_total / 1024:.1f} GiB + "
        f"{peek / 1024:.1f} GiB, total recompute all")

-    print_rank_0("With selective recompute:")
-    solver.dfs_best(dg, 1, config)
-    if 'policy' not in solver.recompute_policy:
-        solver.recompute_policy['policy'] = dg
-    if 'n' not in solver.recompute_policy:
-        solver.recompute_policy['n'] = num_layers
-    rg = solver.recompute_policy['policy']
-    recompute_nodes = []
-    for node in rg.nodes:
-        if rg.nodes[node]['recompute']:
-            recompute_nodes.append(rg.nodes[node]['name'])
-    recompute_n = solver.recompute_policy['n']
-    if "config" in solver.recompute_policy:
-        print_rank_0(f'recompute nodes = {recompute_nodes}, {recompute_n} {solver.recompute_policy["config"]}')
-    full_model = solver.layers_module['layers']
-
-    recompute_policy_new = solver.analyse_policy_to_list(full_model, recompute_n, recompute_nodes)
+    solver.layers_combination_init(dg, 0, config)
+    solver.knapsack_best(config)
+    recompute_policy_new = solver.analyse_policy_to_list()
    if parallel_state.get_tensor_model_parallel_world_size() > 1:
        recompute_policy_new = solver.broadcast_recompute_policy_in_mp(recompute_policy_new)
    print_rank_0(f'recompute_policy_new = {recompute_policy_new}')
-    solver.apply_policy_to_model(recompute_policy_new, full_model)
+    solver.apply_policy_to_model(recompute_policy_new)
    solver.print_list_to_policy(recompute_policy_new)
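Taken together, generate_recompute_policy shrinks to a short pipeline. A sketch of the new call sequence (all names are from this diff; the graph and config setup are elided, as in the hunk above):

solver.layers_combination_init(dg, 0, config)   # enumerate per-layer combinations via DFS
solver.knapsack_best(config)                    # grouped knapsack over all layers
recompute_policy_new = solver.analyse_policy_to_list()
if parallel_state.get_tensor_model_parallel_world_size() > 1:
    recompute_policy_new = solver.broadcast_recompute_policy_in_mp(recompute_policy_new)
solver.apply_policy_to_model(recompute_policy_new)
solver.print_list_to_policy(recompute_policy_new)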
@@ -63,7 +63,7 @@ ds_args=" --deepspeed-activation-checkpointing ${ds_args}"

# Main script
deepspeed pretrain_llama.py \
-       --auto-recompute-device-size 51200 \
+       --auto-recompute-device-size 48128 \
        --use-fused-rotary-pos-emb \
        --triangle-attn \
        --triangle-block-size 1024 \
tests/ut/module/test_auto_recomputing.py (new file, 346 lines)
@@ -0,0 +1,346 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import json
import torch
from ascendspeed.core import parallel_state
from ascendspeed.core.memory.auto_recomputing.autorecompute_solver import solve_graph


class TestAutoRecomputing(unittest.TestCase):
    transformer_layer_info_fmt = """
    {
        "name": "%d",
        "layers": [
            {
                "name": "input_layernorm",
                "memory": 384,
                "time": 1.9710063934326172,
                "input": 64.0,
                "peak_memory": 402705408,
                "forward_cnt": 2,
                "pre_total_time": 3.9420127868652344
            }, {
                "name": "attention",
                "layers": [{
                    "name": "query_key_value",
                    "memory": 192,
                    "time": 9.331226348876953,
                    "input": 64.0,
                    "peak_memory": 402654208,
                    "forward_cnt": 2,
                    "pre_total_time": 18.662452697753906
                }, {
                    "name": "rotary_emb",
                    "memory": 0,
                    "time": 1.7354488372802734,
                    "input": 64.0,
                    "peak_memory": 0,
                    "forward_cnt": 2,
                    "pre_total_time": 3.470897674560547
                }, {
                    "name": "triangle_attn",
                    "layers": [{
                        "name": "scaled_masked_softmax",
                        "memory": 512,
                        "time": 465.08251536976206,
                        "input": 516.0,
                        "peak_memory": 542107136,
                        "forward_cnt": 11,
                        "pre_total_time": 5115.907669067383
                    }],
                    "memory": 1664,
                    "time": 22.87912368774414,
                    "input": 208.0,
                    "peak_memory": 2818581504,
                    "forward_cnt": 2,
                    "pre_total_time": 45.75824737548828
                }, {
                    "name": "dense",
                    "memory": 64,
                    "time": 8.333802223205566,
                    "input": 64.0,
                    "peak_memory": 536871936,
                    "forward_cnt": 2,
                    "pre_total_time": 16.667604446411133
                }],
                "memory": 1792,
                "time": 50.97508430480957,
                "input": 80.0,
                "peak_memory": 2684364288,
                "forward_cnt": 2,
                "pre_total_time": 101.95016860961914
            }, {
                "name": "post_attention_layernorm",
                "memory": 384,
                "time": 1.8906593322753906,
                "input": 64.0,
                "peak_memory": 402705408,
                "forward_cnt": 2,
                "pre_total_time": 3.7813186645507812
            }, {
                "name": "mlp",
                "layers": [{
                    "name": "gate_proj",
                    "memory": 172,
                    "time": 9.36591625213623,
                    "input": 64.0,
                    "peak_memory": 360711168,
                    "forward_cnt": 2,
                    "pre_total_time": 18.73183250427246
                }, {
                    "name": "up_proj",
                    "memory": 172,
                    "time": 8.879423141479492,
                    "input": 64.0,
                    "peak_memory": 360711168,
                    "forward_cnt": 2,
                    "pre_total_time": 17.758846282958984
                }, {
                    "name": "down_proj",
                    "memory": 64,
                    "time": 13.797521591186523,
                    "input": 172.0,
                    "peak_memory": 536871936,
                    "forward_cnt": 2,
                    "pre_total_time": 27.595043182373047
                }],
                "memory": 752,
                "time": 38.39600086212158,
                "input": 64.0,
                "peak_memory": 1258294272,
                "forward_cnt": 2,
                "pre_total_time": 76.79200172424316
            }],
        "memory": 3312,
        "time": 100.17907619476318,
        "input": 64.0,
        "peak_memory": 3942760960,
        "forward_cnt": 2,
        "pre_total_time": 200.35815238952637
    }
    """
    module_all_fmt = """
    {
        "module": [],
        "layers": [{
            "name": "module",
            "layers": [
                {
                    "name": "module",
                    "layers": [
                        {
                            "name": "embedding",
                            "layers": [
                                {
                                    "name": "word_embeddings",
                                    "memory": 256,
                                    "time": 13.043999671936035,
                                    "input": 0.25,
                                    "peak_memory": 268797952,
                                    "forward_cnt": 2,
                                    "pre_total_time": 26.08799934387207
                                }],
                            "memory": 64,
                            "time": 16.85166358947754,
                            "input": 0.25,
                            "peak_memory": 604310016,
                            "forward_cnt": 2,
                            "pre_total_time": 33.70332717895508
                        },
                        {
                            "name": "language_model",
                            "layers": [
                                {
                                    "name": "layers",
                                    "layers": [%s]
                                }],
                            "memory": 4336,
                            "time": 1621.1401224136353,
                            "input": 80.0,
                            "peak_memory": 5331085312,
                            "forward_cnt": 2,
                            "pre_total_time": 3242.2802448272705
                        }],
                    "memory": 4336,
                    "time": 1642.3271894454956,
                    "input": 16.25,
                    "peak_memory": 5398523392,
                    "forward_cnt": 2,
                    "pre_total_time": 3284.654378890991
                }],
            "memory": 4336,
            "time": 1645.2174186706543,
            "input": 16.25,
            "peak_memory": 5398523392,
            "forward_cnt": 2,
            "pre_total_time": 3290.4348373413086
        }],
        "used_mem": 16600,
        "max_device_memory": 58960
    }
    """

    def get_module(self, size):
        module_layers = [self.transformer_layer_info_fmt % i for i in range(size)]
        module_layers_context = self.module_all_fmt % (",".join(module_layers))
        module = json.loads(module_layers_context)
        return module

    def get_transformer_layers(self, module):
        transformer_layers = None
        for sub_module in module["layers"]:
            if sub_module["name"] == "layers":
                transformer_layers = sub_module["layers"]
                break
            if "layers" not in sub_module:
                continue
            transformer_layers = self.get_transformer_layers(sub_module)
        return transformer_layers

    @staticmethod
    def is_recompute_module(module):
        if "recompute" in module and module["recompute"]:
            return True
        return False

    def get_module_recompute_layer(self, module):
        recompute_module_layer = []
        for sub_module in module:
            if self.is_recompute_module(sub_module):
                recompute_module_layer.append(sub_module["name"])
        return recompute_module_layer

    def assert_policy(self, module, policy):
        transformer_layers = self.get_transformer_layers(module)
        for module in transformer_layers:
            # n_full
            if self.is_recompute_module(module):
                if "n_full" not in policy:
                    return False
                if policy["n_full"] <= 0:
                    return False
                policy["n_full"] -= 1
                continue
            sub_module_recompute_layer = self.get_module_recompute_layer(module["layers"])
            # n_without
            if len(sub_module_recompute_layer) == 0:
                if "n_without" not in policy:
                    return False
                if policy["n_without"] <= 0:
                    return False
                policy["n_without"] -= 1
                continue
            # n_selective
            if "n_selective" not in policy or "n_selective_recompute_nodes" not in policy:
                return False
            if policy["n_selective"] <= 0:
                return False
            if len(sub_module_recompute_layer) != len(policy["n_selective_recompute_nodes"]):
                return False
            if len(set(sub_module_recompute_layer) | set(policy["n_selective_recompute_nodes"])) != len(
                    policy["n_selective_recompute_nodes"]):
                return False
            policy["n_selective"] -= 1
        return True

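assert_policy compares the selective node set order-insensitively with a length-plus-union trick: if both lists have the same length and their set union is no larger than the expected set, the two sets are equal. A tiny standalone illustration:

observed = ["attention", "input_layernorm", "post_attention_layernorm"]
expected = ["input_layernorm", "attention", "post_attention_layernorm"]
same = (len(observed) == len(expected)
        and len(set(observed) | set(expected)) == len(expected))
assert same  # order-insensitive equality, as used in assert_policy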
    def do_solve_graph(self, layer_num, pp, device_memory):
        module = self.get_module(layer_num)
        parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1
        solve_graph(module, pp, device_memory)
        return module

    def test_solve_graph_by_module_10_layer_pp_2_52G(self):
        print("=== start to test solve graph: module 10 layer, pp 2, memory 52GB ===")
        module = self.do_solve_graph(10, 2, 52 * 1024)
        policy = {
            "n_without": 4,
            "n_full": 1,
            "n_selective": 5,
            "n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_10_layer_pp_2_54G(self):
        print("=== start to test solve graph: module 10 layer, pp 2, memory 54GB ===")
        module = self.do_solve_graph(10, 2, 54 * 1024)
        policy = {
            "n_without": 1,
            "n_full": 3,
            "n_selective": 6,
            "n_selective_recompute_nodes": ["input_layernorm", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_10_layer_pp_1_52G(self):
        print("=== start to test solve graph: module 10 layer, pp 1, memory 52GB ===")
        module = self.do_solve_graph(10, 1, 52 * 1024)
        policy = {
            "n_without": 10
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_10_layer_pp_1_54G(self):
        print("=== start to test solve graph: module 10 layer, pp 1, memory 54GB ===")
        module = self.do_solve_graph(10, 1, 54 * 1024)
        policy = {
            "n_without": 10
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_32_layer_pp_2_52G(self):
        print("=== start to test solve graph: module 32 layer, pp 2, memory 52GB ===")
        module = self.do_solve_graph(32, 2, 52 * 1024)
        policy = {
            "n_full": 13,
            "n_selective": 19,
            "n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_32_layer_pp_2_54G(self):
        print("=== start to test solve graph: module 32 layer, pp 2, memory 54GB ===")
        module = self.do_solve_graph(32, 2, 54 * 1024)
        policy = {
            "n_full": 12,
            "n_selective": 20,
            "n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_32_layer_pp_1_52G(self):
        print("=== start to test solve graph: module 32 layer, pp 1, memory 52GB ===")
        module = self.do_solve_graph(32, 1, 52 * 1024)
        policy = {
            "n_without": 2,
            "n_selective": 30,
            "n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))

    def test_solve_graph_by_module_32_layer_pp_1_54G(self):
        print("=== start to test solve graph: module 32 layer, pp 1, memory 54GB ===")
        module = self.do_solve_graph(32, 1, 54 * 1024)
        policy = {
            "n_without": 3,
            "n_selective": 29,
            "n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
        }
        self.assertTrue(self.assert_policy(module, policy))


if __name__ == '__main__':
    unittest.main()