From fc92617f54327914b037f150e27e68235798b8ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Asbj=C3=B8rn=20Sloth=20T=C3=B8nnesen?= <ast@2e8.dk>
Date: Tue, 16 Jul 2019 19:08:00 +0000
Subject: refactor inet.new()
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Asbjørn Sloth Tønnesen <ast@2e8.dk>
---
 lua/inet/init.lua | 298 +++++++++++++++++++++++++++++++++---------------------
 1 file changed, 184 insertions(+), 114 deletions(-)

(limited to 'lua/inet')

diff --git a/lua/inet/init.lua b/lua/inet/init.lua
index 3186d30..836e94b 100644
--- a/lua/inet/init.lua
+++ b/lua/inet/init.lua
@@ -3,18 +3,24 @@
 
 local bit32 = require 'bit32'
 
-local inet, inet4, inet6
+local format = string.format
+
+local lshift = bit32.lshift
+local rshift = bit32.rshift
+local band = bit32.band
+local replace = bit32.replace
+local bxor = bit32.bxor
 
 local mt2fam = {}
 
-inet = {}
+local inet = {}
 inet.__index = inet
 
-inet4 = setmetatable({}, inet)
+local inet4 = setmetatable({}, inet)
 inet4.__index = inet4
 mt2fam[inet4] = 4
 
-inet6 = setmetatable({}, inet)
+local inet6 = setmetatable({}, inet)
 inet6.__index = inet6
 mt2fam[inet6] = 6
 
@@ -38,15 +44,6 @@ local function is_inet(t)
 	return mt == inet4 or mt == inet6
 end
 
-function inet.new(ip, mask)
-	local ipv6 = string.find(ip, ':', 1, true)
-	if ipv6 then
-		return inet6.new(ip, mask)
-	else
-		return inet4.new(ip, mask)
-	end
-end
-
 function inet:__len()
 	local mask = self.mask
 	if mask == nil then return 0 end -- make metatable inspectable
@@ -60,12 +57,6 @@ function inet:family()
 	return assert(mt2fam[mt])
 end
 
-local lshift = bit32.lshift
-local rshift = bit32.rshift
-local band = bit32.band
-local replace = bit32.replace
-local bxor = bit32.bxor
-
 local ipv4_parser
 local ipv6_parser
 do
@@ -76,6 +67,7 @@ do
 
 	local digit = R('09')
 
+	local ipv4addr
 	do
 		local dot = S('.')
 		local zero = S('0')
@@ -86,12 +78,12 @@ do
 		local octet32 = R('04') * digit
 		local octet325 = S('5') * R('05')
 		local octet3 = octet31 + (S('2') * (octet32 + octet325))
-		local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0)
-		local ipv4 = octet * dot * octet * dot * octet * dot * octet
+		local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) / tonumber
+		ipv4addr = octet * dot * octet * dot * octet * dot * octet
 		local mask12 = R('12') * digit
 		local mask3 = S('3') * R('02')
 		local netmask = S('/') * C(mask12 + mask3 + digit)
-		ipv4_parser = ipv4 * (netmask + C('')) * -1
+		ipv4_parser = ipv4addr * (netmask + Cc()) * -1
 	end
 
 	do
@@ -102,48 +94,185 @@ do
 		local colcol = C(col * col)
 		local picol = piece * col
 		local colpi = col * piece
-		local full = picol * picol * picol * picol * picol * picol * picol * piece
-		local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*piece)^-1
+		local ipv4embed = ipv4addr / function(a,b,c,d)
+			return lshift(a, 8) + b, lshift(c, 8) + d
+		end
+		local last32bits = (ipv4embed + (picol * piece))
+		local full = picol * picol * picol * picol * picol * picol * last32bits
+		local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*(ipv4embed+piece))^-1
 		local netmask = S('/') * C((digit^-3)) / tonumber
-		ipv6_parser = Ct(full + partial) * ((netmask + C(''))^-1) * -1
+		local pieces = full + partial
+		ipv6_parser = Ct(pieces) * ((netmask + Cc())^-1) * -1
 	end
 end
 
-local function parse4(ipstr)
+local function build_bip(o1, o2, o3, o4)
+	return lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
+end
+
+local function inet4_from_string(ipstr)
 	local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr)
-	if o1 == nil then return nil end
+	if not o1 then return nil, 'parse error' end
 
-	local bip = lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
+	local bip = build_bip(o1, o2, o3, o4)
 	return bip, tonumber(mask)
 end
 
-function inet4.new(ip, mask)
-	local bip, ourmask
-	if type(ip) == 'string' then
-		bip, ourmask = parse4(ip)
-		if bip == nil then
-			return nil
+local function inet4_from_number(bip)
+	return bip
+end
+
+local function inet4_from_table(t)
+	if #t ~= 4 then return nil, 'invalid length' end
+	for i=1,4 do
+		local v = t[i]
+		if type(v) ~= 'number' then return nil, 'invalid number' end
+		if v < 0 or v > 255 then return nil, 'octet out of range' end
+	end
+	return build_bip(t[1], t[2], t[3], t[4])
+end
+
+local inet4_constructors = {
+	string = inet4_from_string,
+	number = inet4_from_number,
+	table  = inet4_from_table,
+}
+
+local function inet6_from_string(ipstr)
+	local pcs, netmask = ipv6_parser:match(ipstr)
+	if not pcs then return nil, 'parse error' end
+	if #pcs > 8 then return nil, 'too many pieces' end
+	local zero_pieces = 8 - #pcs
+	for i=1,#pcs do
+		if pcs[i] == '::' then
+			pcs[i] = 0
+			for j=#pcs,i,-1 do
+				pcs[j+zero_pieces] = pcs[j]
+			end
+			for j=1,zero_pieces do
+				pcs[i+j] = 0
+			end
 		end
-	elseif type(ip) == 'number' then
-		bip = ip
 	end
-	if mask then
-		if type(mask) == 'number' and mask >= 0 and mask <= 32 then
-			ourmask = mask
+	if #pcs > 8 then return nil, 'too many pieces' end
+	if netmask ~= nil and netmask > 128 then
+		return nil, 'invalid netmask'
+	end
+	return pcs, netmask
+end
+
+local function inet6_from_table(t)
+	if #t ~= 8 then return nil, 'invalid length' end
+	for i=1,8 do
+		local v = t[i]
+		if type(v) ~= 'number' then return nil, 'invalid number' end
+		if v < 0 or v > 0xffff then return nil, 'octet out of range' end
+	end
+	return { t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8] }
+end
+
+local inet6_constructors = {
+	string = inet6_from_string,
+	table  = inet6_from_table,
+}
+
+local function decide_mask(from_ip, override, high)
+	local newmask = from_ip
+	if override then
+		if type(override) == 'number' and override >= 0 and override <= high then
+			if from_ip ~= nil then
+				return nil, 'multiple masks supplied'
+			end
+			newmask = override
 		else
-			error('invalid mask')
+			return nil, 'invalid mask'
 		end
 	else
-		if not ourmask then
-			ourmask = 32
+		if not newmask then
+			newmask = high
 		end
 	end
+	return newmask
+end
+
+local function generic_new(constructors, high, ip, mask)
+	local type_ip = type(ip)
+	local constructor = constructors[type_ip]
+	if not constructor then
+		return nil, 'invalid ip argument'
+	end
+	local iir, ourmask = constructor(ip)
+	if not iir then
+		return nil, ourmask
+	end
+	local outmask, err = decide_mask(ourmask, mask, high)
+	if not outmask then return nil, err end
+
+	return iir, outmask
+end
+
+local function new_inet4(ip, mask)
+	local bip, outmask = generic_new(inet4_constructors, 32, ip, mask)
+	if not bip then return nil, outmask end
 	return setmetatable({
 		bip = bip,
-		mask = ourmask,
+		mask = outmask,
 	}, inet4)
 end
 
+local function new_inet6(ip, mask)
+	local pcs, outmask = generic_new(inet6_constructors, 128, ip, mask)
+	if not pcs then return nil, outmask end
+
+	local r = setmetatable({
+		pcs = pcs,
+		mask = outmask,
+	}, inet6)
+
+	-- ensure that the result is balanced
+	if not r:is_balanced() then
+		r:balance()
+		return nil, tostring(r)..' unbalanced'
+	end
+
+	return r
+end
+
+local function new_inet(ip, mask)
+	local is_ipv6
+	local type_ip = type(ip)
+	if type_ip == 'string' then
+		is_ipv6 = string.find(ip, ':', 1, true)
+	elseif type_ip == 'number' then
+		is_ipv6 = false
+	elseif is_inet4(ip) then
+		mask = mask or #ip
+		ip = ip.bip
+		is_ipv6 = false
+	elseif is_inet6(ip) then
+		mask = mask or #ip
+		ip = ip.pcs
+		is_ipv6 = true
+	elseif type_ip == 'table' then
+		local n = #ip
+		if n == 8 then
+			is_ipv6 = true
+		elseif n == 4 then
+			is_ipv6 = false
+		else
+			return nil, 'invalid table'
+		end
+	else
+		return nil, 'invalid ip type'
+	end
+
+	if is_ipv6 then
+		return new_inet6(ip, mask)
+	else
+		return new_inet4(ip, mask)
+	end
+end
+
 local function tostr4(self, withmask)
 	-- return human readable
 	local bip, mask = self.bip, self.mask
@@ -172,24 +301,24 @@ function inet4:cidrstring()
 end
 
 function inet4:__add(n)
-	return inet4.new(self.bip + n, self.mask)
+	return new_inet4(self.bip + n, self.mask)
 end
 
 function inet4:__sub(n)
-	return inet4.new(self.bip - n, self.mask)
+	return new_inet4(self.bip - n, self.mask)
 end
 
 function inet4:__mul(n)
 	local new = self.bip + (n * math.pow(2, 32 - self.mask))
-	return inet4.new(new, self.mask)
+	return new_inet4(new, self.mask)
 end
 
 function inet4:__div(n)
-	return inet4.new(self.bip, n)
+	return new_inet4(self.bip, n)
 end
 
 function inet4:__pow(n)
-	return inet4.new(self.bip, self.mask + n)
+	return new_inet4(self.bip, self.mask + n)
 end
 
 function inet4:__lt(other)
@@ -222,12 +351,12 @@ end
 
 function inet4:network()
 	local hostbits = 32 - self.mask
-	return inet4.new(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
+	return new_inet4(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
 end
 
 function inet4:netmask()
 	local hostbits = 32 - self.mask
-	return inet4.new(replace(0xffffffff, 0, 0, hostbits), 32)
+	return new_inet4(replace(0xffffffff, 0, 0, hostbits), 32)
 end
 
 function inet4:flip()
@@ -236,66 +365,7 @@ function inet4:flip()
 	if mask == 0 then return nil end
 	local hostbits = 32 - mask
 	local flipbit = lshift(1, hostbits)
-	return inet4.new(bxor(self.bip, flipbit), mask)
-end
-
-local function parse6(ipstr)
-	local pcs, netmask = ipv6_parser:match(ipstr)
-	if not pcs then return nil end
-	if #pcs > 8 then return nil, 'too many pieces' end
-	local zero_pieces = 8 - #pcs
-	for i=1,#pcs do
-		if pcs[i] == '::' then
-			pcs[i] = 0
-			for j=1,#pcs-i do
-				pcs[i+j+zero_pieces] = pcs[i+j]
-			end
-			for j=1,zero_pieces do
-				pcs[i+j] = 0
-			end
-		end
-	end
-	if #pcs > 8 then return nil, 'too many pieces' end
-	if netmask == '' then
-		netmask = 128
-	elseif netmask > 128 then
-		return nil, 'invalid netmask'
-	end
-	return pcs, netmask
-end
-
-function inet6.new(ip, netmask)
-	local pcs, err
-	if type(ip) == 'string' then
-		pcs, err = parse6(ip)
-		if pcs == nil then
-			return nil, err
-		end
-		if not netmask then
-			netmask = err
-		end
-	elseif type(ip) == 'table' then
-		pcs = { ip[1], ip[2], ip[3], ip[4],
-		        ip[5], ip[6], ip[7], ip[8] }
-		if not netmask then
-			netmask = 128
-		end
-	else
-		return nil
-	end
-
-	local r = setmetatable({
-		pcs = pcs,
-		mask = netmask,
-	}, inet6)
-
-	-- ensure that the result is balanced
-	if not r:is_balanced() then
-		r:balance()
-		return nil, tostring(r)..' unbalanced'
-	end
-
-	return r
+	return new_inet4(bxor(self.bip, flipbit), mask)
 end
 
 -- each ipv6 address is stored as eight pieces
@@ -415,7 +485,7 @@ function inet6:cidrstring()
 end
 
 function inet6:clone()
-	return inet6.new(self.pcs, self.mask)
+	return new_inet6(self.pcs, self.mask)
 end
 
 function inet6:__eq(other)
@@ -433,11 +503,11 @@ function inet6:__eq(other)
 end
 
 function inet6:__div(n)
-	return inet6.new(self.pcs, n)
+	return new_inet6(self.pcs, n)
 end
 
 function inet6:__pow(n)
-	return inet6.new(self.pcs, self.mask + n)
+	return new_inet6(self.pcs, self.mask + n)
 end
 
 function inet6:__add(n)
@@ -466,7 +536,7 @@ function inet6:network()
 			newpcs[i] = lshift(rshift(pcs[i], netbitsleft), netbitsleft)
 		end
 	end
-	return inet6.new(newpcs, netbits)
+	return new_inet6(newpcs, netbits)
 end
 
 function inet6:flip()
-- 
cgit v1.2.1