aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAsbjørn Sloth Tønnesen <ast@2e8.dk>2019-07-16 19:08:00 +0000
committerAsbjørn Sloth Tønnesen <ast@2e8.dk>2019-07-16 19:08:00 +0000
commitfc92617f54327914b037f150e27e68235798b8ae (patch)
tree0220365d983aa8c2f0f9d370e4919fb0a0f64b1f
parent23732642cb8c27de229a52ba201af2809aac6ddd (diff)
downloadlua-inet-fc92617f54327914b037f150e27e68235798b8ae.tar.gz
lua-inet-fc92617f54327914b037f150e27e68235798b8ae.tar.xz
lua-inet-fc92617f54327914b037f150e27e68235798b8ae.zip
refactor inet.new()
Signed-off-by: Asbjørn Sloth Tønnesen <ast@2e8.dk>
-rw-r--r--lua/inet/init.lua298
1 files changed, 184 insertions, 114 deletions
diff --git a/lua/inet/init.lua b/lua/inet/init.lua
index 3186d30..836e94b 100644
--- a/lua/inet/init.lua
+++ b/lua/inet/init.lua
@@ -3,18 +3,24 @@
local bit32 = require 'bit32'
-local inet, inet4, inet6
+local format = string.format
+
+local lshift = bit32.lshift
+local rshift = bit32.rshift
+local band = bit32.band
+local replace = bit32.replace
+local bxor = bit32.bxor
local mt2fam = {}
-inet = {}
+local inet = {}
inet.__index = inet
-inet4 = setmetatable({}, inet)
+local inet4 = setmetatable({}, inet)
inet4.__index = inet4
mt2fam[inet4] = 4
-inet6 = setmetatable({}, inet)
+local inet6 = setmetatable({}, inet)
inet6.__index = inet6
mt2fam[inet6] = 6
@@ -38,15 +44,6 @@ local function is_inet(t)
return mt == inet4 or mt == inet6
end
-function inet.new(ip, mask)
- local ipv6 = string.find(ip, ':', 1, true)
- if ipv6 then
- return inet6.new(ip, mask)
- else
- return inet4.new(ip, mask)
- end
-end
-
function inet:__len()
local mask = self.mask
if mask == nil then return 0 end -- make metatable inspectable
@@ -60,12 +57,6 @@ function inet:family()
return assert(mt2fam[mt])
end
-local lshift = bit32.lshift
-local rshift = bit32.rshift
-local band = bit32.band
-local replace = bit32.replace
-local bxor = bit32.bxor
-
local ipv4_parser
local ipv6_parser
do
@@ -76,6 +67,7 @@ do
local digit = R('09')
+ local ipv4addr
do
local dot = S('.')
local zero = S('0')
@@ -86,12 +78,12 @@ do
local octet32 = R('04') * digit
local octet325 = S('5') * R('05')
local octet3 = octet31 + (S('2') * (octet32 + octet325))
- local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0)
- local ipv4 = octet * dot * octet * dot * octet * dot * octet
+ local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) / tonumber
+ ipv4addr = octet * dot * octet * dot * octet * dot * octet
local mask12 = R('12') * digit
local mask3 = S('3') * R('02')
local netmask = S('/') * C(mask12 + mask3 + digit)
- ipv4_parser = ipv4 * (netmask + C('')) * -1
+ ipv4_parser = ipv4addr * (netmask + Cc()) * -1
end
do
@@ -102,48 +94,185 @@ do
local colcol = C(col * col)
local picol = piece * col
local colpi = col * piece
- local full = picol * picol * picol * picol * picol * picol * picol * piece
- local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*piece)^-1
+ local ipv4embed = ipv4addr / function(a,b,c,d)
+ return lshift(a, 8) + b, lshift(c, 8) + d
+ end
+ local last32bits = (ipv4embed + (picol * piece))
+ local full = picol * picol * picol * picol * picol * picol * last32bits
+ local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*(ipv4embed+piece))^-1
local netmask = S('/') * C((digit^-3)) / tonumber
- ipv6_parser = Ct(full + partial) * ((netmask + C(''))^-1) * -1
+ local pieces = full + partial
+ ipv6_parser = Ct(pieces) * ((netmask + Cc())^-1) * -1
end
end
-local function parse4(ipstr)
+local function build_bip(o1, o2, o3, o4)
+ return lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
+end
+
+local function inet4_from_string(ipstr)
local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr)
- if o1 == nil then return nil end
+ if not o1 then return nil, 'parse error' end
- local bip = lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
+ local bip = build_bip(o1, o2, o3, o4)
return bip, tonumber(mask)
end
-function inet4.new(ip, mask)
- local bip, ourmask
- if type(ip) == 'string' then
- bip, ourmask = parse4(ip)
- if bip == nil then
- return nil
+local function inet4_from_number(bip)
+ return bip
+end
+
+local function inet4_from_table(t)
+ if #t ~= 4 then return nil, 'invalid length' end
+ for i=1,4 do
+ local v = t[i]
+ if type(v) ~= 'number' then return nil, 'invalid number' end
+ if v < 0 or v > 255 then return nil, 'octet out of range' end
+ end
+ return build_bip(t[1], t[2], t[3], t[4])
+end
+
+local inet4_constructors = {
+ string = inet4_from_string,
+ number = inet4_from_number,
+ table = inet4_from_table,
+}
+
+local function inet6_from_string(ipstr)
+ local pcs, netmask = ipv6_parser:match(ipstr)
+ if not pcs then return nil, 'parse error' end
+ if #pcs > 8 then return nil, 'too many pieces' end
+ local zero_pieces = 8 - #pcs
+ for i=1,#pcs do
+ if pcs[i] == '::' then
+ pcs[i] = 0
+ for j=#pcs,i,-1 do
+ pcs[j+zero_pieces] = pcs[j]
+ end
+ for j=1,zero_pieces do
+ pcs[i+j] = 0
+ end
end
- elseif type(ip) == 'number' then
- bip = ip
end
- if mask then
- if type(mask) == 'number' and mask >= 0 and mask <= 32 then
- ourmask = mask
+ if #pcs > 8 then return nil, 'too many pieces' end
+ if netmask ~= nil and netmask > 128 then
+ return nil, 'invalid netmask'
+ end
+ return pcs, netmask
+end
+
+local function inet6_from_table(t)
+ if #t ~= 8 then return nil, 'invalid length' end
+ for i=1,8 do
+ local v = t[i]
+ if type(v) ~= 'number' then return nil, 'invalid number' end
+ if v < 0 or v > 0xffff then return nil, 'octet out of range' end
+ end
+ return { t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8] }
+end
+
+local inet6_constructors = {
+ string = inet6_from_string,
+ table = inet6_from_table,
+}
+
+local function decide_mask(from_ip, override, high)
+ local newmask = from_ip
+ if override then
+ if type(override) == 'number' and override >= 0 and override <= high then
+ if from_ip ~= nil then
+ return nil, 'multiple masks supplied'
+ end
+ newmask = override
else
- error('invalid mask')
+ return nil, 'invalid mask'
end
else
- if not ourmask then
- ourmask = 32
+ if not newmask then
+ newmask = high
end
end
+ return newmask
+end
+
+local function generic_new(constructors, high, ip, mask)
+ local type_ip = type(ip)
+ local constructor = constructors[type_ip]
+ if not constructor then
+ return nil, 'invalid ip argument'
+ end
+ local iir, ourmask = constructor(ip)
+ if not iir then
+ return nil, ourmask
+ end
+ local outmask, err = decide_mask(ourmask, mask, high)
+ if not outmask then return nil, err end
+
+ return iir, outmask
+end
+
+local function new_inet4(ip, mask)
+ local bip, outmask = generic_new(inet4_constructors, 32, ip, mask)
+ if not bip then return nil, outmask end
return setmetatable({
bip = bip,
- mask = ourmask,
+ mask = outmask,
}, inet4)
end
+local function new_inet6(ip, mask)
+ local pcs, outmask = generic_new(inet6_constructors, 128, ip, mask)
+ if not pcs then return nil, outmask end
+
+ local r = setmetatable({
+ pcs = pcs,
+ mask = outmask,
+ }, inet6)
+
+ -- ensure that the result is balanced
+ if not r:is_balanced() then
+ r:balance()
+ return nil, tostring(r)..' unbalanced'
+ end
+
+ return r
+end
+
+local function new_inet(ip, mask)
+ local is_ipv6
+ local type_ip = type(ip)
+ if type_ip == 'string' then
+ is_ipv6 = string.find(ip, ':', 1, true)
+ elseif type_ip == 'number' then
+ is_ipv6 = false
+ elseif is_inet4(ip) then
+ mask = mask or #ip
+ ip = ip.bip
+ is_ipv6 = false
+ elseif is_inet6(ip) then
+ mask = mask or #ip
+ ip = ip.pcs
+ is_ipv6 = true
+ elseif type_ip == 'table' then
+ local n = #ip
+ if n == 8 then
+ is_ipv6 = true
+ elseif n == 4 then
+ is_ipv6 = false
+ else
+ return nil, 'invalid table'
+ end
+ else
+ return nil, 'invalid ip type'
+ end
+
+ if is_ipv6 then
+ return new_inet6(ip, mask)
+ else
+ return new_inet4(ip, mask)
+ end
+end
+
local function tostr4(self, withmask)
-- return human readable
local bip, mask = self.bip, self.mask
@@ -172,24 +301,24 @@ function inet4:cidrstring()
end
function inet4:__add(n)
- return inet4.new(self.bip + n, self.mask)
+ return new_inet4(self.bip + n, self.mask)
end
function inet4:__sub(n)
- return inet4.new(self.bip - n, self.mask)
+ return new_inet4(self.bip - n, self.mask)
end
function inet4:__mul(n)
local new = self.bip + (n * math.pow(2, 32 - self.mask))
- return inet4.new(new, self.mask)
+ return new_inet4(new, self.mask)
end
function inet4:__div(n)
- return inet4.new(self.bip, n)
+ return new_inet4(self.bip, n)
end
function inet4:__pow(n)
- return inet4.new(self.bip, self.mask + n)
+ return new_inet4(self.bip, self.mask + n)
end
function inet4:__lt(other)
@@ -222,12 +351,12 @@ end
function inet4:network()
local hostbits = 32 - self.mask
- return inet4.new(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
+ return new_inet4(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
end
function inet4:netmask()
local hostbits = 32 - self.mask
- return inet4.new(replace(0xffffffff, 0, 0, hostbits), 32)
+ return new_inet4(replace(0xffffffff, 0, 0, hostbits), 32)
end
function inet4:flip()
@@ -236,66 +365,7 @@ function inet4:flip()
if mask == 0 then return nil end
local hostbits = 32 - mask
local flipbit = lshift(1, hostbits)
- return inet4.new(bxor(self.bip, flipbit), mask)
-end
-
-local function parse6(ipstr)
- local pcs, netmask = ipv6_parser:match(ipstr)
- if not pcs then return nil end
- if #pcs > 8 then return nil, 'too many pieces' end
- local zero_pieces = 8 - #pcs
- for i=1,#pcs do
- if pcs[i] == '::' then
- pcs[i] = 0
- for j=1,#pcs-i do
- pcs[i+j+zero_pieces] = pcs[i+j]
- end
- for j=1,zero_pieces do
- pcs[i+j] = 0
- end
- end
- end
- if #pcs > 8 then return nil, 'too many pieces' end
- if netmask == '' then
- netmask = 128
- elseif netmask > 128 then
- return nil, 'invalid netmask'
- end
- return pcs, netmask
-end
-
-function inet6.new(ip, netmask)
- local pcs, err
- if type(ip) == 'string' then
- pcs, err = parse6(ip)
- if pcs == nil then
- return nil, err
- end
- if not netmask then
- netmask = err
- end
- elseif type(ip) == 'table' then
- pcs = { ip[1], ip[2], ip[3], ip[4],
- ip[5], ip[6], ip[7], ip[8] }
- if not netmask then
- netmask = 128
- end
- else
- return nil
- end
-
- local r = setmetatable({
- pcs = pcs,
- mask = netmask,
- }, inet6)
-
- -- ensure that the result is balanced
- if not r:is_balanced() then
- r:balance()
- return nil, tostring(r)..' unbalanced'
- end
-
- return r
+ return new_inet4(bxor(self.bip, flipbit), mask)
end
-- each ipv6 address is stored as eight pieces
@@ -415,7 +485,7 @@ function inet6:cidrstring()
end
function inet6:clone()
- return inet6.new(self.pcs, self.mask)
+ return new_inet6(self.pcs, self.mask)
end
function inet6:__eq(other)
@@ -433,11 +503,11 @@ function inet6:__eq(other)
end
function inet6:__div(n)
- return inet6.new(self.pcs, n)
+ return new_inet6(self.pcs, n)
end
function inet6:__pow(n)
- return inet6.new(self.pcs, self.mask + n)
+ return new_inet6(self.pcs, self.mask + n)
end
function inet6:__add(n)
@@ -466,7 +536,7 @@ function inet6:network()
newpcs[i] = lshift(rshift(pcs[i], netbitsleft), netbitsleft)
end
end
- return inet6.new(newpcs, netbits)
+ return new_inet6(newpcs, netbits)
end
function inet6:flip()