!330 Support adaptive selective recomputation: optimize the recompute-policy selection algorithm and add UTs

Merge pull request !330 from Gou Jicheng/dev_from_master_only_recompute
i-robot 2023-11-30 11:15:49 +00:00 committed by Gitee
commit a5d76b70b1
3 changed files with 584 additions and 115 deletions

View File

@@ -13,15 +13,20 @@ class GraphSolver:
self.total_recompute_cost = 0
self.total_forward_cost = 0
self.layers_module = None
self.transformer_module = None
self.recompute_policy = {}
self.layers_combination = []
self.layer_full_recompute_combination = None
self.layer_without_recompute_combination = None
self.layer_recompute_one_combination = None
@staticmethod
def print_recompute_op(graph):
def get_recompute_op(graph):
recompute_nodes = []
for node in graph.nodes:
if graph.nodes[node]['recompute']:
recompute_nodes.append(graph.nodes[node]['name'])
print_rank_0(f'recompute nodes = {recompute_nodes}')
return recompute_nodes
@staticmethod
def dg_init(no_recompute_layer):
@@ -48,28 +53,39 @@
group=parallel_state.get_tensor_model_parallel_group())
return recompute_policy_tensor.cpu().numpy().tolist()
@staticmethod
def apply_policy_to_model(recompute_policy_new, full_model):
if recompute_policy_new[1] == 0:
for idx, module in enumerate(full_model):
if idx < recompute_policy_new[0]:
module['recompute'] = True
else:
module['recompute'] = False
idx_submodule = 0
for layer in module['layers']:
if recompute_policy_new[idx_submodule + 2] == 1:
layer['recompute'] = True
idx_submodule += 1
else:
for idx, module in enumerate(full_model):
if idx < recompute_policy_new[0]:
module['recompute'] = False
idx_submodule = 0
for layer in module['layers']:
if recompute_policy_new[idx_submodule + 2] == 1:
layer['recompute'] = True
idx_submodule += 1
def set_recompute_info_to_module(self, module, recompute_nodes, recompute):
if not recompute:
module["recompute"] = False
for layer in module["layers"]:
layer["recompute"] = False
return
if len(recompute_nodes) == 0:
module["recompute"] = True
return
sub_modules = module["layers"]
recompute_nodes_length = len(recompute_nodes)
for i in range(recompute_nodes_length):
if recompute_nodes[i] == self.layer_recompute_one_combination.broadcast_value:
sub_modules[i]["recompute"] = True
continue
sub_modules[i]["recompute"] = False
def apply_policy_to_model(self, recompute_policy_list):
full_layers = self.layers_module["layers"]
if len(recompute_policy_list) == 0:
return
idx = 0
for policy in recompute_policy_list:
n = policy[0]
recompute = False
recompute_nodes = []
if policy[1] != self.layer_without_recompute_combination.broadcast_value:
recompute = True
if policy[1] == self.layer_recompute_one_combination.broadcast_value:
recompute_nodes = policy[2:]
for i in range(idx, idx + n):
self.set_recompute_info_to_module(full_layers[i], recompute_nodes, recompute)
idx += n
# minimizing memory results in recomputing everything
def calculate_cost_mem(self, g: nx.DiGraph, idx):
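Note on the new policy format: apply_policy_to_model now takes a flat list of integer rows so the result can be broadcast across tensor-parallel ranks. Each row is [n, kind, flag_0, ..., flag_m-1]: n consecutive transformer layers get no recompute (kind 2), full recompute (kind 0) or selective recompute (kind 1), with one flag per sub-layer that only matters in the selective case. A minimal, self-contained sketch of that decoding, assuming a simplified layer dict; the helper name apply_entry and the toy model below are illustrative only:

    # Broadcast values as defined in layers_combination_init in this diff.
    FULL_RECOMPUTE, SELECTIVE, NO_RECOMPUTE = 0, 1, 2

    def apply_entry(layers, start, entry):
        """Apply one policy row [n, kind, flag_0, ..., flag_m-1] to n consecutive layers."""
        n, kind, flags = entry[0], entry[1], entry[2:]
        for layer in layers[start:start + n]:
            if kind == NO_RECOMPUTE:
                layer["recompute"] = False
                for sub in layer["layers"]:
                    sub["recompute"] = False
            elif kind == FULL_RECOMPUTE:
                layer["recompute"] = True        # checkpoint the whole layer
            else:                                # SELECTIVE: per-sub-layer flags
                layer["recompute"] = False
                for sub, flag in zip(layer["layers"], flags):
                    sub["recompute"] = (flag == SELECTIVE)
        return start + n

    # Toy model: three transformer layers, each with two sub-layers.
    layers = [{"recompute": False,
               "layers": [{"name": "attention", "recompute": False},
                          {"name": "mlp", "recompute": False}]} for _ in range(3)]
    idx = 0
    for row in [[1, FULL_RECOMPUTE, 0, 0],    # layer 0: full recompute
                [2, SELECTIVE, 1, 2]]:        # layers 1-2: recompute "attention" only
        idx = apply_entry(layers, idx, row)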
@@ -115,98 +131,220 @@ class GraphSolver:
successor_cnt += 1
return global_max_mem
def cal_transformer_memory(self, model_layers):
s = 0
if 'layers' in model_layers:
for layer in model_layers['layers']:
if str.isdigit(layer['name']):
s += layer['memory']
else:
s += self.cal_transformer_memory(layer)
return s
def cal_non_transformer_memory(self, model):
# total memory used
model_memory = model['layers'][0]['memory']
model_layers = model['layers'][0]
transformer_layer_memory = self.cal_transformer_memory(model_layers)
transformer_layer_memory = self.transformer_module['memory']
non_size = model_memory - transformer_layer_memory
print_rank_0(f"non size {model_memory} {non_size}")
return non_size
def dfs_best(self, g, idx, config):
def layers_combination_init(self, g, idx, config):
if idx == 0:
self.layer_full_recompute_combination = LayerCombination({
"name": "full_recompute",
"num": config["nlayer"],
"memory": config["chp_input"],
"cost": config["chp_time"],
"broadcast_value": 0,
"policy_name": "n_full"
})
self.layers_combination.append(self.layer_full_recompute_combination)
self.layer_without_recompute_combination = LayerCombination({
"name": "without_recompute",
"num": config["nlayer"],
"memory": config["full_activation"],
"cost": 0,
"broadcast_value": 2,
"policy_name": "n_without"
})
self.layers_combination.append(self.layer_without_recompute_combination)
if idx >= len(config['layers']):
self.search_recompute_policy(g, config)
recompute_nodes = self.get_recompute_op(g)
if len(recompute_nodes) == len(config['layers']) or len(recompute_nodes) == 0:
return
stash_mem_per_layer, recompute_cost = self.calculate_cost_mem(g, 0)
self.layer_recompute_one_combination = LayerCombination({
"name": ",".join(recompute_nodes),
"num": config["nlayer"],
"memory": stash_mem_per_layer,
"cost": recompute_cost,
"broadcast_value": 1,
"policy_name": "n_selective"
})
self.layers_combination.append(self.layer_recompute_one_combination)
return
g.nodes[idx]['recompute'] = False
self.dfs_best(g, idx + 1, config)
self.layers_combination_init(g, idx + 1, config)
g.nodes[idx]['recompute'] = True
self.dfs_best(g, idx + 1, config)
self.layers_combination_init(g, idx + 1, config)
def search_recompute_policy(self, g, config):
stash_mem_per_layer, recompute_cost = self.calculate_cost_mem(g, 0)
peek = self.calculate_cost_peek(g, 0, 0, 0)
for i in range(config['nlayer']):
# if it is selective
stash_mem_total = (stash_mem_per_layer * i + config['full_activation'] * (config['nlayer'] - i)) * config['pp']
if config['static_memory_layer'] + stash_mem_total + peek < config['device_memory']:
recompute_total = recompute_cost * i # * config['pp']
if recompute_total < self.total_recompute_cost:
self.total_recompute_cost = recompute_total
self.print_recompute_op(g)
self.recompute_policy['config'] = 'n_selective'
self.recompute_policy['policy'] = g.copy()
self.recompute_policy['n'] = i
try:
print_rank_0(
f"recompute policy {i}-selective: {config['static_memory_layer'] / 1024:.1f} GiB + "
f"{stash_mem_total / 1024:.1f} GiB + {peek / 1024:.1f} GiB, "
f"speed up compared with all recompute"
f" {(self.total_forward_cost - recompute_total) / (4 * self.total_forward_cost) * 100:.2f}%")
except ZeroDivisionError:
print_rank_0("param error. total_forward_cost is 0.")
def get_max_goods_value(self, idx, ans, config):
i, j, k = idx[0], idx[1], idx[2]
pre_step_ans = ans[i - 1][j - k]
if k == 0:
return pre_step_ans
# if there is not enough memory
stash_mem_total = (stash_mem_per_layer * (config['nlayer'] - i) + config['chp_input'] * i) * config['pp']
if config['static_memory_layer'] + stash_mem_total + peek < config['device_memory']:
recompute_total = (
recompute_cost * (config['nlayer'] - i) + config['chp_time'] * i)
if recompute_total < self.total_recompute_cost:
self.total_recompute_cost = recompute_total
self.print_recompute_op(g)
self.recompute_policy['config'] = 'n_full'
self.recompute_policy['policy'] = g.copy()
self.recompute_policy['n'] = i
try:
print_rank_0(
f"recompute policy {i}-full: {config['static_memory_layer'] / 1024:.1f} GiB + "
f"{stash_mem_total / 1024:.1f} ({stash_mem_per_layer * (config['nlayer'] - i)} + "
f"{config['chp_input'] * i}) GiB + {peek / 1024:.1f} GiB, "
f"speed up compared with all recompute "
f"{(self.total_forward_cost - recompute_total) / (4 * self.total_forward_cost) * 100:.2f}%")
except ZeroDivisionError:
print_rank_0("param error. total_forward_cost is 0.")
goods_value = ans[i][j]
memory = pre_step_ans.memory + k * self.layers_combination[i].memory
cost = pre_step_ans.cost + k * self.layers_combination[i].cost
if pre_step_ans.cost == float('inf'):
cost = k * self.layers_combination[i].cost
try:
device_memory = max(config["device_memory"] - config["static_memory_layer"], 0) / config["pp"]
except ZeroDivisionError:
device_memory = max(config["device_memory"] - config["static_memory_layer"], 0)
print_rank_0("[ERROR] pipeline model parallel world size is 0. ")
def analyse_policy_to_list(self, full_model, recompute_n, recompute_nodes):
if "config" in self.recompute_policy and self.recompute_policy["config"] != "n_full":
recompute_policy_list = [int(recompute_n), 1]
else:
recompute_policy_list = [int(recompute_n), 0]
for layer in full_model[0]['layers']:
if layer["name"] in recompute_nodes:
recompute_policy_list.append(1)
if device_memory >= memory and cost <= goods_value.cost:
goods_value.memory = memory
goods_value.cost = cost
goods_value.layer_names.clear()
if len(pre_step_ans.layer_names) > 0:
goods_value.layer_names.extend(pre_step_ans.layer_names)
goods_value.layer_names.extend(self.layers_combination[i].name for _ in range(k))
return goods_value
def print_recompute_policy(self, memory, cost, config):
fmt_str = "With selective recompute:\n"
for k, v in self.recompute_policy.items():
if k == self.layer_full_recompute_combination.name:
policy_name = self.layer_full_recompute_combination.policy_name
elif k == self.layer_without_recompute_combination.name:
policy_name = self.layer_without_recompute_combination.policy_name
else:
policy_name = self.layer_recompute_one_combination.policy_name
fmt_str += "recomputeNodes=[{}], ".format(k)
fmt_str += "{} {}; ".format(v, policy_name)
all_recompute_cost = len(self.layers_module["layers"]) * self.layer_full_recompute_combination.cost
try:
performance = (all_recompute_cost - cost) / (all_recompute_cost * 4)
except ZeroDivisionError:
performance = 0
print_rank_0("[ERROR] all recompute cost is 0. ")
fmt_str += "\ntotal mem cost: {:.1f} GiB + {:.1f} GiB, speed up compared with all recompute {:.2%}".format(
config["static_memory_layer"] / 1024, memory * config["pp"] / 1024, performance)
print_rank_0(fmt_str)
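The reported speed-up divides the saved recompute time by four times the all-recompute cost. A plausible reading (not stated in the code) is that one training step under full recompute is treated as roughly four forward-equivalents per layer: the original forward, the recomputed forward, and a backward costed at about two forwards, so the saving is expressed as a share of the whole step rather than of the forward pass alone. A quick numeric check of the formula with made-up numbers:

    all_recompute_cost = 200.0   # hypothetical: recomputing every layer adds 200 time units
    cost = 50.0                  # the chosen policy only re-runs 50 units of forward work
    speedup = (all_recompute_cost - cost) / (all_recompute_cost * 4)
    print(f"{speedup:.2%}")      # 18.75% of the (approximate) full step saved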
def get_all_layer_policy(self, combination_num, layer_num, ans, config):
layer_nodes = [self.layer_full_recompute_combination.name for _ in range(layer_num)]
memory = layer_num * self.layer_full_recompute_combination.memory
cost = layer_num * self.layer_full_recompute_combination.cost
for i in range(layer_num, 0, -1):
size = layer_num - len(ans[combination_num][i].layer_names)
if size != layer_num:
l_nodes = []
l_nodes.extend(ans[combination_num][i].layer_names)
# if policies have not been found for all layers, the remaining layers use the full recompute policy.
l_nodes.extend(self.layer_full_recompute_combination.name for _ in range(size))
l_memory = ans[combination_num][i].memory + size * self.layer_full_recompute_combination.memory
l_cost = ans[combination_num][i].cost + size * self.layer_full_recompute_combination.cost
if l_cost < cost:
cost = l_cost
memory = l_memory
layer_nodes.clear()
layer_nodes.extend(l_nodes)
for nodes in layer_nodes:
if nodes not in self.recompute_policy.keys():
self.recompute_policy.update({nodes: 1})
continue
recompute_policy_list.append(0)
self.recompute_policy.update({nodes: self.recompute_policy[nodes] + 1})
self.print_recompute_policy(memory, cost, config)
def knapsack_best(self, config):
combination_num = len(self.layers_combination)
layer_num = len(self.layers_module["layers"])
# make combination index ids begin at 1.
self.layers_combination.insert(0, None)
# init ans
ans = [[GoodsValue() for _ in range(layer_num + 1)] for _ in range(combination_num + 1)]
# find max goods value
for i in range(1, combination_num + 1):
for j in range(layer_num + 1):
k = 0
while k <= self.layers_combination[i].num and k <= j:
ans[i][j] = self.get_max_goods_value([i, j, k], ans, config)
k += 1
self.get_all_layer_policy(combination_num, layer_num, ans, config)
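knapsack_best treats each LayerCombination as an item type with a per-layer memory footprint (the "weight") and a per-layer recompute cost (the quantity to minimize), and ans[i][j] keeps the best way found so far to cover j transformer layers using the first i combinations within the per-pipeline-stage memory budget. The standalone sketch below mirrors that bounded-knapsack recursion and the update rule of get_max_goods_value, but drops the layer-name bookkeeping and uses plain tuples instead of GoodsValue; the numbers are invented:

    import math

    def cheapest_assignment(combos, layer_num, budget):
        """combos: list of (memory_per_layer, cost_per_layer, max_count).
        Returns (cost, memory) of the cheapest mix found that covers layer_num layers
        within the memory budget, or (inf, 0.0) if none fits."""
        # ans[j] = best (cost, memory) seen so far for covering j layers
        ans = [(math.inf, 0.0)] * (layer_num + 1)
        ans[0] = (0.0, 0.0)
        for mem, cost, max_count in combos:
            nxt = list(ans)
            for j in range(layer_num + 1):
                for k in range(1, min(max_count, j) + 1):
                    prev_cost, prev_mem = ans[j - k]
                    if prev_cost == math.inf:
                        continue
                    c, m = prev_cost + k * cost, prev_mem + k * mem
                    if m <= budget and c < nxt[j][0]:   # feasible and cheaper
                        nxt[j] = (c, m)
            ans = nxt
        return ans[layer_num]

    # Toy data: (memory, cost, max_count) for no-recompute, selective and full-recompute.
    combos = [(10.0, 0.0, 8), (6.0, 2.0, 8), (2.0, 5.0, 8)]
    print(cheapest_assignment(combos, 8, 60.0))   # (10.0, 60.0): 3 layers untouched, 5 selective

As in the diff, only one best state is kept per cell (one GoodsValue per ans[i][j]), so this is a cost-greedy variant of the two-constraint knapsack rather than an exhaustive search.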
def analyse_policy_to_list(self):
recompute_policy_list = []
full_module_layers = self.layers_module["layers"][0]["layers"]
module_layers_num = len(full_module_layers)
for nodes_name, v in self.recompute_policy.items():
nodes_count = [v]
if nodes_name == self.layer_without_recompute_combination.name:
broadcast_value = self.layer_without_recompute_combination.broadcast_value
nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
elif nodes_name == self.layer_full_recompute_combination.name:
broadcast_value = self.layer_full_recompute_combination.broadcast_value
nodes_count.extend(broadcast_value for _ in range(module_layers_num + 1))
else:
nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
recompute_nodes = nodes_name.split(",")
for layer in full_module_layers:
if layer["name"] in recompute_nodes:
nodes_count.append(self.layer_recompute_one_combination.broadcast_value)
continue
nodes_count.append(self.layer_without_recompute_combination.broadcast_value)
recompute_policy_list.append(nodes_count)
return recompute_policy_list
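For reference, each row produced here has length module_layers_num + 2: a repeat count, one broadcast value for the whole layer (0 = full recompute, 1 = selective, 2 = no recompute), then one value per sub-layer that is only meaningful for selective rows. A hypothetical example for a block with four sub-layers; the counts and node choice are illustrative, not from a real run:

    # Sub-layers assumed: input_layernorm, attention, post_attention_layernorm, mlp
    recompute_policy_list = [
        [3, 2, 2, 2, 2, 2],   # 3 layers: no recompute at all
        [1, 0, 0, 0, 0, 0],   # 1 layer: full recompute
        [6, 1, 1, 1, 1, 2],   # 6 layers: selective, recompute everything except mlp
    ]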
def print_list_to_policy(self, recompute_policy_list):
layer_names = self.layers_module["layers"][0]["layers"]
module_layers_num = len(layer_names)
if len(recompute_policy_list) == 0:
return
fmt_str = ">> final selective strategy <<\n"
for policy in recompute_policy_list:
n = policy[0]
if policy[1] == self.layer_without_recompute_combination.broadcast_value:
policy_name = self.layer_without_recompute_combination.policy_name
elif policy[1] == self.layer_full_recompute_combination.broadcast_value:
policy_name = self.layer_full_recompute_combination.policy_name
else:
policy_name = self.layer_recompute_one_combination.policy_name
policy = policy[2:]
nodes = []
for i in range(module_layers_num):
if policy[i] == self.layer_recompute_one_combination.broadcast_value:
nodes.append(layer_names[i]["name"])
fmt_str += "recomputeNodes=[{}], ".format(",".join(nodes))
fmt_str += "{} {}\n".format(n, policy_name)
print_rank_0(fmt_str)
def get_layers_module(self, model):
if "name" in model and model["name"] == "layers":
self.layers_module = model
return
return True
if "layers" not in model:
return
return False
has_transformer_layer = False
for sub_model in model["layers"]:
self.get_layers_module(sub_model)
has_transformer_layer = (has_transformer_layer or self.get_layers_module(sub_model))
if has_transformer_layer:
self.transformer_module = model
return False
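get_layers_module searches the profiled module tree for the node literally named "layers" (the transformer stack); the change also records that node's direct parent as transformer_module, which cal_non_transformer_memory now uses to subtract the stack's memory from the model total. A standalone sketch of the same recursive lookup, using the dict layout of the profiling JSON in the new UT (the helper name is illustrative):

    def find_layers_module(model):
        """Return (layers_module, its_direct_parent) or (None, None)."""
        if model.get("name") == "layers":
            return model, None
        for sub in model.get("layers", []):
            found, parent = find_layers_module(sub)
            if found is not None:
                # the first caller whose direct child is the "layers" node fills in the parent
                return found, parent if parent is not None else model
        return None, None

    tree = {"name": "language_model",
            "layers": [{"name": "layers", "layers": [{"name": "0"}, {"name": "1"}]}]}
    layers_module, transformer_module = find_layers_module(tree)
    print(layers_module["name"], transformer_module["name"])   # layers language_model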
class LayerCombination:
def __init__(self, config):
self.name = config["name"]
self.num = config["num"]
self.memory = config["memory"]
self.cost = config["cost"]
self.broadcast_value = config["broadcast_value"]
self.policy_name = config["policy_name"]
class GoodsValue:
def __init__(self):
self.layer_names = []
self.memory = 0
self.cost = float('inf')
def solve_graph(model, pp, device_memory):
@@ -256,25 +394,10 @@ def generate_recompute_policy(solver, config):
print_rank_0(
f"With all recompute: total mem cost: {static_memory / 1024:.1f} GiB + {stash_mem_total / 1024:.1f} GiB + "
f"{peek / 1024:.1f} GiB, total recompute all")
print_rank_0("With selective recompute:")
solver.dfs_best(dg, 1, config)
if 'policy' not in solver.recompute_policy:
solver.recompute_policy['policy'] = dg
if 'n' not in solver.recompute_policy:
solver.recompute_policy['n'] = num_layers
rg = solver.recompute_policy['policy']
recompute_nodes = []
for node in rg.nodes:
if rg.nodes[node]['recompute']:
recompute_nodes.append(rg.nodes[node]['name'])
recompute_n = solver.recompute_policy['n']
if "config" in solver.recompute_policy:
print_rank_0(f'recompute nodes = {recompute_nodes}, {recompute_n} {solver.recompute_policy["config"]}')
full_model = solver.layers_module['layers']
recompute_policy_new = solver.analyse_policy_to_list(full_model, recompute_n, recompute_nodes)
solver.layers_combination_init(dg, 0, config)
solver.knapsack_best(config)
recompute_policy_new = solver.analyse_policy_to_list()
if parallel_state.get_tensor_model_parallel_world_size() > 1:
recompute_policy_new = solver.broadcast_recompute_policy_in_mp(recompute_policy_new)
print_rank_0(f'recompute_policy_new = {recompute_policy_new}')
solver.apply_policy_to_model(recompute_policy_new, full_model)
solver.apply_policy_to_model(recompute_policy_new)
solver.print_list_to_policy(recompute_policy_new)
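After this change the selective-recompute strategy comes from the chain below instead of dfs_best printing candidates as it searches. Condensed outline of the post-merge flow in generate_recompute_policy (solver construction, profiling and dg_init happen earlier in solve_graph; argument plumbing abbreviated): layers_combination_init(dg, 0, config) runs a DFS over the per-layer sub-module graph to collect candidate combinations (no recompute, full recompute, and selective sets of sub-layers); knapsack_best(config) chooses how many layers use each combination under the per-stage memory budget; analyse_policy_to_list() flattens that choice into integer rows; broadcast_recompute_policy_in_mp(...) keeps tensor-parallel ranks in sync; apply_policy_to_model(...) writes the recompute flags back into the profiled tree; print_list_to_policy(...) logs the final strategy.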

View File

@@ -63,7 +63,7 @@ ds_args=" --deepspeed-activation-checkpointing ${ds_args}"
# Main script
deepspeed pretrain_llama.py \
--auto-recompute-device-size 51200 \
--auto-recompute-device-size 48128 \
--use-fused-rotary-pos-emb \
--triangle-attn \
--triangle-block-size 1024 \
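Note on the script change: if --auto-recompute-device-size is the per-device memory budget in MB that feeds the solver's device_memory (the new UT passes values such as 52 * 1024 and 54 * 1024 in the same units), then this lowers the budget from 51200 MB (50 GiB) to 48128 MB (47 GiB), leaving more headroom for static memory. That reading of the flag is inferred from the solver, not stated in the script.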

View File

@@ -0,0 +1,346 @@
# coding=utf-8
# Copyright (c) 2023, HUAWEI CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import json
import torch
from ascendspeed.core import parallel_state
from ascendspeed.core.memory.auto_recomputing.autorecompute_solver import solve_graph
class TestAutoRecomputing(unittest.TestCase):
transformer_layer_info_fmt = """
{
"name": "%d",
"layers": [
{
"name": "input_layernorm",
"memory": 384,
"time": 1.9710063934326172,
"input": 64.0,
"peak_memory": 402705408,
"forward_cnt": 2,
"pre_total_time": 3.9420127868652344
}, {
"name": "attention",
"layers": [{
"name": "query_key_value",
"memory": 192,
"time": 9.331226348876953,
"input": 64.0,
"peak_memory": 402654208,
"forward_cnt": 2,
"pre_total_time": 18.662452697753906
}, {
"name": "rotary_emb",
"memory": 0,
"time": 1.7354488372802734,
"input": 64.0,
"peak_memory": 0,
"forward_cnt": 2,
"pre_total_time": 3.470897674560547
}, {
"name": "triangle_attn",
"layers": [{
"name": "scaled_masked_softmax",
"memory": 512,
"time": 465.08251536976206,
"input": 516.0,
"peak_memory": 542107136,
"forward_cnt": 11,
"pre_total_time": 5115.907669067383
}],
"memory": 1664,
"time": 22.87912368774414,
"input": 208.0,
"peak_memory": 2818581504,
"forward_cnt": 2,
"pre_total_time": 45.75824737548828
}, {
"name": "dense",
"memory": 64,
"time": 8.333802223205566,
"input": 64.0,
"peak_memory": 536871936,
"forward_cnt": 2,
"pre_total_time": 16.667604446411133
}],
"memory": 1792,
"time": 50.97508430480957,
"input": 80.0,
"peak_memory": 2684364288,
"forward_cnt": 2,
"pre_total_time": 101.95016860961914
}, {
"name": "post_attention_layernorm",
"memory": 384,
"time": 1.8906593322753906,
"input": 64.0,
"peak_memory": 402705408,
"forward_cnt": 2,
"pre_total_time": 3.7813186645507812
}, {
"name": "mlp",
"layers": [{
"name": "gate_proj",
"memory": 172,
"time": 9.36591625213623,
"input": 64.0,
"peak_memory": 360711168,
"forward_cnt": 2,
"pre_total_time": 18.73183250427246
}, {
"name": "up_proj",
"memory": 172,
"time": 8.879423141479492,
"input": 64.0,
"peak_memory": 360711168,
"forward_cnt": 2,
"pre_total_time": 17.758846282958984
}, {
"name": "down_proj",
"memory": 64,
"time": 13.797521591186523,
"input": 172.0,
"peak_memory": 536871936,
"forward_cnt": 2,
"pre_total_time": 27.595043182373047
}],
"memory": 752,
"time": 38.39600086212158,
"input": 64.0,
"peak_memory": 1258294272,
"forward_cnt": 2,
"pre_total_time": 76.79200172424316
}],
"memory": 3312,
"time": 100.17907619476318,
"input": 64.0,
"peak_memory": 3942760960,
"forward_cnt": 2,
"pre_total_time": 200.35815238952637
}
"""
module_all_fmt = """
{
"module": [],
"layers": [{
"name": "module",
"layers": [
{
"name": "module",
"layers": [
{
"name": "embedding",
"layers": [
{
"name": "word_embeddings",
"memory": 256,
"time": 13.043999671936035,
"input": 0.25,
"peak_memory": 268797952,
"forward_cnt": 2,
"pre_total_time": 26.08799934387207
}],
"memory": 64,
"time": 16.85166358947754,
"input": 0.25,
"peak_memory": 604310016,
"forward_cnt": 2,
"pre_total_time": 33.70332717895508
},
{
"name": "language_model",
"layers": [
{
"name": "layers",
"layers": [%s]
}],
"memory": 4336,
"time": 1621.1401224136353,
"input": 80.0,
"peak_memory": 5331085312,
"forward_cnt": 2,
"pre_total_time": 3242.2802448272705
}],
"memory": 4336,
"time": 1642.3271894454956,
"input": 16.25,
"peak_memory": 5398523392,
"forward_cnt": 2,
"pre_total_time": 3284.654378890991
}],
"memory": 4336,
"time": 1645.2174186706543,
"input": 16.25,
"peak_memory": 5398523392,
"forward_cnt": 2,
"pre_total_time": 3290.4348373413086
}],
"used_mem": 16600,
"max_device_memory": 58960
}
"""
def get_module(self, size):
module_layers = [self.transformer_layer_info_fmt % i for i in range(size)]
module_layers_context = self.module_all_fmt % (",".join(module_layers))
module = json.loads(module_layers_context)
return module
def get_transformer_layers(self, module):
transformer_layers = None
for sub_module in module["layers"]:
if sub_module["name"] == "layers":
transformer_layers = sub_module["layers"]
break
if "layers" not in sub_module:
continue
transformer_layers = self.get_transformer_layers(sub_module)
return transformer_layers
@staticmethod
def is_recompute_module(module):
if "recompute" in module and module["recompute"]:
return True
return False
def get_module_recompute_layer(self, module):
recompute_module_layer = []
for sub_module in module:
if self.is_recompute_module(sub_module):
recompute_module_layer.append(sub_module["name"])
return recompute_module_layer
def assert_policy(self, module, policy):
transformer_layers = self.get_transformer_layers(module)
for module in transformer_layers:
# n_full
if self.is_recompute_module(module):
if "n_full" not in policy:
return False
if policy["n_full"] <= 0:
return False
policy["n_full"] -= 1
continue
sub_module_recompute_layer = self.get_module_recompute_layer(module["layers"])
# n_without
if len(sub_module_recompute_layer) == 0:
if "n_without" not in policy:
return False
if policy["n_without"] <= 0:
return False
policy["n_without"] -= 1
continue
# n_selective
if "n_selective" not in policy or "n_selective_recompute_nodes" not in policy:
return False
if policy["n_selective"] <= 0:
return False
if len(sub_module_recompute_layer) != len(policy["n_selective_recompute_nodes"]):
return False
if len(set(sub_module_recompute_layer) | set(policy["n_selective_recompute_nodes"])) != len(
policy["n_selective_recompute_nodes"]):
return False
policy["n_selective"] -= 1
return True
def do_solve_graph(self, layer_num, pp, device_memory):
module = self.get_module(layer_num)
parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = 1
solve_graph(module, pp, device_memory)
return module
def test_solve_graph_by_module_10_layer_pp_2_52G(self):
print("=== start to test solve graph: module 10 layer, pp 2, memory 52GB ===")
module = self.do_solve_graph(10, 2, 52 * 1024)
policy = {
"n_without": 4,
"n_full": 1,
"n_selective": 5,
"n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_10_layer_pp_2_54G(self):
print("=== start to test solve graph: module 10 layer, pp 2, memory 54GB ===")
module = self.do_solve_graph(10, 2, 54 * 1024)
policy = {
"n_without": 1,
"n_full": 3,
"n_selective": 6,
"n_selective_recompute_nodes": ["input_layernorm", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_10_layer_pp_1_52G(self):
print("=== start to test solve graph: module 10 layer, pp 1, memory 52GB ===")
module = self.do_solve_graph(10, 1, 52 * 1024)
policy = {
"n_without": 10
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_10_layer_pp_1_54G(self):
print("=== start to test solve graph: module 10 layer, pp 1, memory 54GB ===")
module = self.do_solve_graph(10, 1, 54 * 1024)
policy = {
"n_without": 10
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_32_layer_pp_2_52G(self):
print("=== start to test solve graph: module 32 layer, pp 2, memory 52GB ===")
module = self.do_solve_graph(32, 2, 52 * 1024)
policy = {
"n_full": 13,
"n_selective": 19,
"n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_32_layer_pp_2_54G(self):
print("=== start to test solve graph: module 32 layer, pp 2, memory 54GB ===")
module = self.do_solve_graph(32, 2, 54 * 1024)
policy = {
"n_full": 12,
"n_selective": 20,
"n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_32_layer_pp_1_52G(self):
print("=== start to test solve graph: module 32 layer, pp 1, memory 52GB ===")
module = self.do_solve_graph(32, 1, 52 * 1024)
policy = {
"n_without": 2,
"n_selective": 30,
"n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
def test_solve_graph_by_module_32_layer_pp_1_54G(self):
print("=== start to test solve graph: module 32 layer, pp 1, memory 54GB ===")
module = self.do_solve_graph(32, 1, 54 * 1024)
policy = {
"n_without": 3,
"n_selective": 29,
"n_selective_recompute_nodes": ["input_layernorm", "attention", "post_attention_layernorm"]
}
self.assertTrue(self.assert_policy(module, policy))
if __name__ == '__main__':
unittest.main()
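The new tests pin parallel_state._MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE to 1 before calling solve_graph, so they run without launching a distributed job; with the __main__ guard above they can be executed directly with python, or through python -m unittest (the file path is whatever the repository uses for this UT and is not shown in the diff).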