aboutsummaryrefslogtreecommitdiffstats
path: root/lua/inet/set.lua
blob: 5999f65b9637299ed13bb51af79890e5f053a5f4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
-- Module table. Public entry points are aggregate(), iterator() and
-- loopiterator(), which operate on arrays ("sets") of address/network
-- objects (presumably inet objects — they are used via :network(), :flip(),
-- `/`, `^`, `+` and `>=`). table_compact() and has() stay file-private.
local M = {}

--- Close up gaps in t[1..n] in place, keeping the surviving entries in order.
-- Any falsy slot (nil or false) counts as a gap. Slots left over at the tail
-- of 1..n are set to nil; indices above `n` are never touched.
-- @param t table whose array slots 1..n may contain holes
-- @param n highest index to consider (passed in because `#t` is unreliable
--   once the array has holes)
local function table_compact(t, n)
	local write = 1
	for read = 1, n do
		local value = t[read]
		if value then
			-- only move when a gap has been seen before this entry
			if read ~= write then
				t[write] = value
			end
			write = write + 1
		end
	end
	-- everything from `write` up to `n` is now either stale or already empty
	for k = write, n do
		t[k] = nil
	end
end

--- Merge a list of networks in place.
-- Each entry is first replaced by its :network(). Then, repeatedly until a
-- full pass changes nothing:
--   * an entry equal (==) to another entry is dropped as a duplicate;
--   * an entry whose :flip() counterpart is also present is replaced by
--     (entry ^ -1):network() and the counterpart is dropped
--     (presumably `^ -1` widens the prefix by one bit — confirm in inet).
-- Finally the holes left behind are squeezed out with table_compact.
-- NOTE(review): only exact duplicates and flip() counterparts are merged; an
-- entry that lies inside a larger entry is not removed.
-- @param t array of address/network objects; modified in place, no return
function M.aggregate(t)
	local count = #t
	for k = 1, count do
		t[k] = t[k]:network()
	end

	local changed
	repeat
		changed = false
		for a = 1, count do
			local net = t[a]
			if net then
				local sibling = net:flip()
				for b = 1, count do
					if b ~= a then
						local other = t[b]
						if net == other then
							-- exact duplicate of the current entry
							t[b] = nil
							changed = true
						elseif other == sibling then
							-- pair found: replace with the enclosing network
							t[a] = (net ^ -1):network()
							t[b] = nil
							changed = true
						end
					end
				end
			end
		end
	until not changed

	table_compact(t, count)
end

--- Test whether `addr` is covered by `set` and not excluded from it.
-- The first element satisfying `element >= addr` (presumably "element
-- contains addr" for inet objects) decides the result: if `set.exclude`
-- exists, the answer is the negation of has(set.exclude, addr), which is
-- checked recursively so exclusions may nest.
-- @param set array of elements, optionally with an `exclude` field; must
--   not be nil (asserted)
-- @param addr value compared against each element with `>=`
-- @return boolean
local function has(set, addr)
	assert(set)
	local covered = false
	for idx = 1, #set do
		if set[idx] >= addr then
			covered = true
			break
		end
	end
	if not covered then
		return false
	end
	local excluded = set.exclude
	return not (excluded and has(excluded, addr))
end

--- Build a stateful iterator over the individual addresses of a set.
-- Elements of `set` are walked in order. Each element is stepped from its own
-- address by `+ 1` until the address's :network() differs from the
-- element's, then the walk moves on to the next element. Addresses covered by
-- `set.exclude` (when that field is present) are skipped.
-- NOTE(review): yielded addresses are given a hard-coded /32 mask, so this
-- presumably assumes IPv4 sets — confirm before using it with IPv6.
-- @param set array of address/network objects, optionally with an
--   `exclude` field holding another set
-- @return iterator function that yields one address per call and nil once
--   exhausted; or nil, 'empty set' when `set` has no elements
function M.iterator(set)
	local excl = set.exclude
	if #set < 1 then return nil, 'empty set' end
	local i = 1
	local addr = set[i]
	local net = addr:network()

	local function iter()
		if not addr then return end
		local ret = addr/32
		addr = addr + 1
		if addr:network() ~= net then
			-- stepped past the current element's network: advance to the next
			i = i + 1
			addr = set[i]
			if addr then
				net = addr:network()
			end
		end
		-- `exclude` is optional, and has() asserts on a nil set
		if excl and has(excl, ret) then
			return iter() -- proper tail call: long excluded runs don't grow the stack
		end
		return ret
	end

	return iter
end

--- Build an iterator that cycles through a set indefinitely.
-- Wraps M.iterator: when the current pass is exhausted, a fresh iterator is
-- created and iteration restarts from the first element of `set`.
-- If a restarted pass produces nothing, the call returns nil.
-- @param set same shape as accepted by M.iterator
-- @return iterator function; or nil, err when M.iterator rejects the set
--   (e.g. nil, 'empty set')
function M.loopiterator(set)
	local orig_iter, err = M.iterator(set)
	if not orig_iter then
		-- propagate M.iterator's error instead of capturing a nil iterator
		return nil, err
	end
	local function iter()
		local addr = orig_iter()
		if not addr then
			-- current pass exhausted: restart from the beginning of the set
			local fresh = M.iterator(set)
			if not fresh then
				return nil -- set has become empty since the iterator was built
			end
			orig_iter = fresh
			addr = orig_iter()
		end
		return addr
	end
	return iter
end

-- export the module table
return M