feat: add minimum spanning tree algorithm

This commit is contained in:
chenluli 2020-07-29 17:55:23 +08:00 committed by Yanyan Wang
parent 30dae14260
commit 83501b91c7
6 changed files with 362 additions and 1 deletions

View File

@ -50,7 +50,7 @@
"site:deploy": "npm run site:build && gh-pages -d public",
"start": "npm run site:develop",
"test": "jest",
"test-live": "DEBUG_MODE=1 jest --watch ./tests/unit/algorithm/find-path-spec",
"test-live": "DEBUG_MODE=1 jest --watch ./tests/unit/algorithm/mst-spec.ts",
"lint-staged:js": "eslint --ext .js,.jsx,.ts,.tsx",
"watch": "father build -w",
"cdn": "antv-bin upload -n @antv/g6"

View File

@ -7,3 +7,4 @@ export { default as floydWarshall } from './floydWarshall'
export { default as getConnectedComponents } from './connected-component'
export { detectAllCycles, detectAllDirectedCycle, detectAllUndirectedCycle } from './detect-cycle'
export { findShortestPath, findAllPath } from './find-path'
export { default as minimumSpanningTree } from './mst'

114
src/algorithm/mst.ts Normal file
View File

@ -0,0 +1,114 @@
import { IGraph } from '../interface/graph';
import { IEdge } from '../interface/item';
import UnionFind from './structs/union-find';
import MinBinaryHeap from './structs/binary-heap'
/**
* Prim algorithmuse priority queue O(E+V*logV), V: 节点数量E: 边的数量
* refer: https://en.wikipedia.org/wiki/Prim%27s_algorithm
* @param graph
* @param weight
*/
const primMST = (graph: IGraph, weight?: string) => {
const selectedEdges = []
const nodes = graph.getNodes()
if (nodes.length === 0) {
return selectedEdges;
}
// 从nodes[0]开始
let currNode = nodes[0];
let visited = new Set()
visited.add(currNode)
// 用二叉堆维护距已加入节点的其他节点的边的权值
const compareWeight = (a: IEdge, b: IEdge) => {
if (weight) {
return (a.getModel()[weight] as number) - (b.getModel()[weight] as number)
} else {
return 0
}
}
let edgeQueue = new MinBinaryHeap(compareWeight)
currNode.getEdges().forEach(edge => {
edgeQueue.insert(edge);
});
while (!edgeQueue.isEmpty()) {
// 选取与已加入的结点之间边权最小的结点
// console.log(edgeQueue.list.map(edge => edge.getModel().weight))
let currEdge = edgeQueue.delMin();
const source = currEdge.getSource()
const target = currEdge.getTarget()
if (visited.has(source) && visited.has(target)) continue
selectedEdges.push(currEdge);
if (!visited.has(source)) {
visited.add(source)
source.getEdges().forEach(edge => {
edgeQueue.insert(edge);
});
}
if (!visited.has(target)) {
visited.add(target)
target.getEdges().forEach(edge => {
edgeQueue.insert(edge);
});
}
}
return selectedEdges;
}
/**
* Kruskal algorithm O(E*logE), E: 边的数量
* refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
* @param graph
* @param weight
* @return IEdge[] MST的边的数组
*/
const kruskalMST = (graph: IGraph, weight?: string): IEdge[] => {
const selectedEdges = []
if (graph.getNodes().length === 0) {
return selectedEdges;
}
// 若指定weight则将所有的边按权值从小到大排序
const edges = graph.getEdges().map(edge => edge)
if (weight) {
edges.sort((a, b) => {
return a.getModel()[weight] as number - (b.getModel()[weight] as number);
})
}
let disjointSet = new UnionFind(graph.getNodes().map(n => n.get('id')));
// 从权值最小的边开始如果这条边连接的两个节点于图G中不在同一个连通分量中则添加这条边
// 直到遍历完所有点或边
while (edges.length > 0) {
let curEdge = edges.shift();
let source = curEdge.getSource().get('id');
let target = curEdge.getTarget().get('id');
if (!disjointSet.connected(source, target)) {
selectedEdges.push(curEdge)
disjointSet.union(source, target);
}
}
return selectedEdges;
}
/**
*
* refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
* @param graph
* @param weight
* @param algo 'prim' | 'kruskal'
* @return IEdge[] MST的边的数组
*/
export default function mst(graph: IGraph, weight?: string, algo?: string) {
const algos = {
'prim': primMST,
'kruskal': kruskalMST,
}
if (!algo) return kruskalMST(graph, weight)
return algos[algo](graph, weight)
}

View File

@ -0,0 +1,73 @@
const defaultCompare = (a, b) => { return a - b }
export default class MinBinaryHeap {
list: any[];
compareFn: (a: any, b: any) => number;
constructor(compareFn = defaultCompare) {
this.compareFn = compareFn;
this.list = [];
}
getLeft(index) {
return 2 * index + 1;
}
getRight(index) {
return 2 * index + 2;
}
getParent(index) {
if (index === 0) {
return null;
}
return Math.floor((index - 1) / 2);
}
isEmpty() {
return this.list.length <= 0;
}
top() {
return this.isEmpty() ? undefined : this.list[0];
}
delMin() {
let top = this.top()
const bottom = this.list.pop();
if (this.list.length > 0) {
this.list[0] = bottom
this.moveDown(0)
}
return top
}
insert(value) {
if (value !== null) {
this.list.push(value);
const index = this.list.length - 1;
this.moveUp(index);
return true;
}
return false;
}
moveUp(index) {
let parent = this.getParent(index);
while (index && index > 0 && this.compareFn(this.list[parent], this.list[index]) > 0) {
// swap
const tmp = this.list[parent]
this.list[parent] = this.list[index]
this.list[index] = tmp
// [this.list[index], this.list[parent]] = [this.list[parent], this.list[index]]
index = parent;
parent = this.getParent(index);
}
}
moveDown(index) {
let element = index;
const left = this.getLeft(index);
const right = this.getRight(index);
const size = this.list.length;
if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
element = left;
} else if (right !== null && right < size && this.compareFn(this.list[element], this.list[right]) > 0) {
element = right;
}
if (index !== element) {
[this.list[index], this.list[element]] = [this.list[element], this.list[index]]
this.moveDown(element);
}
}
}

View File

@ -0,0 +1,43 @@
/**
* Disjoint set to support quick union
*/
export default class UnionFind {
count: number;
parent: {};
constructor(items: (number | string)[]) {
this.count = items.length;
this.parent = {};
for (let i of items) {
this.parent[i] = i;
}
}
// find the root of the item
find(item) {
while (this.parent[item] !== item) {
item = this.parent[item];
}
return item;
}
union(a, b) {
let rootA = this.find(a);
let rootB = this.find(b);
if (rootA === rootB) return;
// make the element with smaller root the parent
if (rootA < rootB) {
if (this.parent[b] != b) this.union(this.parent[b], a);
this.parent[b] = this.parent[a];
} else {
if (this.parent[a] != a) this.union(this.parent[a], b);
this.parent[a] = this.parent[b];
}
}
// whether a and b are connected, i.e. a and b have the same root
connected(a, b) {
return this.find(a) === this.find(b);
}
}

View File

@ -0,0 +1,130 @@
import G6, { Algorithm } from '../../../src';
const { minimumSpanningTree } = Algorithm
const div = document.createElement('div');
div.id = 'container';
document.body.appendChild(div);
const data = {
nodes: [
{
id: 'A'
},
{
id: 'B'
},
{
id: 'C'
},
{
id: 'D'
},
{
id: 'E'
},
{
id: 'F'
},
{
id: 'G'
},
],
edges: [
{
source: 'A',
target: 'B',
weight: 1,
},
{
source: 'B',
target: 'C',
weight: 1,
},
{
source: 'A',
target: 'C',
weight: 2,
},
{
source: 'D',
target: 'A',
weight: 3,
},
{
source: 'D',
target: 'E',
weight: 4,
},
{
source: 'E',
target: 'F',
weight: 2,
},
{
source: 'F',
target: 'D',
weight: 3,
}
]
}
data.nodes.forEach(node => node['label'] = node.id)
data.edges.forEach(edge => edge['label'] = edge.weight)
describe('minimumSpanningTree', () => {
const graph = new G6.Graph({
container: 'container',
width: 500,
height: 500,
layout: {
type: 'force'
},
modes: {
default: ['drag-node']
},
defaultNode: {
labelCfg: {
style: {
fontSize: 12,
},
},
},
defaultEdge: {
style: {
endArrow: true,
},
labelCfg: {
style: {
fontSize: 12,
},
},
},
edgeStateStyles: {
mst: {
stroke: 'red'
}
}
})
graph.data(data)
graph.render()
it('test kruskal algorithm', () => {
let result = minimumSpanningTree(graph, 'weight')
let totalWeight = 0
for (let edge of result) {
graph.setItemState(edge, 'mst', true);
totalWeight += edge.getModel()['weight']
}
expect(totalWeight).toEqual(10);
});
it('test prim algorithm', () => {
let result = minimumSpanningTree(graph, 'weight', 'prim')
let totalWeight = 0
for (let edge of result) {
graph.setItemState(edge, 'mst', true);
totalWeight += edge.getModel()['weight']
}
expect(totalWeight).toEqual(10);
});
});