From fc92617f54327914b037f150e27e68235798b8ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Asbj=C3=B8rn=20Sloth=20T=C3=B8nnesen?= Date: Tue, 16 Jul 2019 19:08:00 +0000 Subject: refactor inet.new() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Asbjørn Sloth Tønnesen --- lua/inet/init.lua | 298 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 184 insertions(+), 114 deletions(-) diff --git a/lua/inet/init.lua b/lua/inet/init.lua index 3186d30..836e94b 100644 --- a/lua/inet/init.lua +++ b/lua/inet/init.lua @@ -3,18 +3,24 @@ local bit32 = require 'bit32' -local inet, inet4, inet6 +local format = string.format + +local lshift = bit32.lshift +local rshift = bit32.rshift +local band = bit32.band +local replace = bit32.replace +local bxor = bit32.bxor local mt2fam = {} -inet = {} +local inet = {} inet.__index = inet -inet4 = setmetatable({}, inet) +local inet4 = setmetatable({}, inet) inet4.__index = inet4 mt2fam[inet4] = 4 -inet6 = setmetatable({}, inet) +local inet6 = setmetatable({}, inet) inet6.__index = inet6 mt2fam[inet6] = 6 @@ -38,15 +44,6 @@ local function is_inet(t) return mt == inet4 or mt == inet6 end -function inet.new(ip, mask) - local ipv6 = string.find(ip, ':', 1, true) - if ipv6 then - return inet6.new(ip, mask) - else - return inet4.new(ip, mask) - end -end - function inet:__len() local mask = self.mask if mask == nil then return 0 end -- make metatable inspectable @@ -60,12 +57,6 @@ function inet:family() return assert(mt2fam[mt]) end -local lshift = bit32.lshift -local rshift = bit32.rshift -local band = bit32.band -local replace = bit32.replace -local bxor = bit32.bxor - local ipv4_parser local ipv6_parser do @@ -76,6 +67,7 @@ do local digit = R('09') + local ipv4addr do local dot = S('.') local zero = S('0') @@ -86,12 +78,12 @@ do local octet32 = R('04') * digit local octet325 = S('5') * R('05') local octet3 = octet31 + (S('2') * (octet32 + octet325)) - local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) - local ipv4 = octet * dot * octet * dot * octet * dot * octet + local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) / tonumber + ipv4addr = octet * dot * octet * dot * octet * dot * octet local mask12 = R('12') * digit local mask3 = S('3') * R('02') local netmask = S('/') * C(mask12 + mask3 + digit) - ipv4_parser = ipv4 * (netmask + C('')) * -1 + ipv4_parser = ipv4addr * (netmask + Cc()) * -1 end do @@ -102,48 +94,185 @@ do local colcol = C(col * col) local picol = piece * col local colpi = col * piece - local full = picol * picol * picol * picol * picol * picol * picol * piece - local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*piece)^-1 + local ipv4embed = ipv4addr / function(a,b,c,d) + return lshift(a, 8) + b, lshift(c, 8) + d + end + local last32bits = (ipv4embed + (picol * piece)) + local full = picol * picol * picol * picol * picol * picol * last32bits + local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*(ipv4embed+piece))^-1 local netmask = S('/') * C((digit^-3)) / tonumber - ipv6_parser = Ct(full + partial) * ((netmask + C(''))^-1) * -1 + local pieces = full + partial + ipv6_parser = Ct(pieces) * ((netmask + Cc())^-1) * -1 end end -local function parse4(ipstr) +local function build_bip(o1, o2, o3, o4) + return lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4 +end + +local function inet4_from_string(ipstr) local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr) - if o1 == nil then return nil end + if not o1 then return nil, 'parse error' end - local bip = lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4 + local bip = build_bip(o1, o2, o3, o4) return bip, tonumber(mask) end -function inet4.new(ip, mask) - local bip, ourmask - if type(ip) == 'string' then - bip, ourmask = parse4(ip) - if bip == nil then - return nil +local function inet4_from_number(bip) + return bip +end + +local function inet4_from_table(t) + if #t ~= 4 then return nil, 'invalid length' end + for i=1,4 do + local v = t[i] + if type(v) ~= 'number' then return nil, 'invalid number' end + if v < 0 or v > 255 then return nil, 'octet out of range' end + end + return build_bip(t[1], t[2], t[3], t[4]) +end + +local inet4_constructors = { + string = inet4_from_string, + number = inet4_from_number, + table = inet4_from_table, +} + +local function inet6_from_string(ipstr) + local pcs, netmask = ipv6_parser:match(ipstr) + if not pcs then return nil, 'parse error' end + if #pcs > 8 then return nil, 'too many pieces' end + local zero_pieces = 8 - #pcs + for i=1,#pcs do + if pcs[i] == '::' then + pcs[i] = 0 + for j=#pcs,i,-1 do + pcs[j+zero_pieces] = pcs[j] + end + for j=1,zero_pieces do + pcs[i+j] = 0 + end end - elseif type(ip) == 'number' then - bip = ip end - if mask then - if type(mask) == 'number' and mask >= 0 and mask <= 32 then - ourmask = mask + if #pcs > 8 then return nil, 'too many pieces' end + if netmask ~= nil and netmask > 128 then + return nil, 'invalid netmask' + end + return pcs, netmask +end + +local function inet6_from_table(t) + if #t ~= 8 then return nil, 'invalid length' end + for i=1,8 do + local v = t[i] + if type(v) ~= 'number' then return nil, 'invalid number' end + if v < 0 or v > 0xffff then return nil, 'octet out of range' end + end + return { t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8] } +end + +local inet6_constructors = { + string = inet6_from_string, + table = inet6_from_table, +} + +local function decide_mask(from_ip, override, high) + local newmask = from_ip + if override then + if type(override) == 'number' and override >= 0 and override <= high then + if from_ip ~= nil then + return nil, 'multiple masks supplied' + end + newmask = override else - error('invalid mask') + return nil, 'invalid mask' end else - if not ourmask then - ourmask = 32 + if not newmask then + newmask = high end end + return newmask +end + +local function generic_new(constructors, high, ip, mask) + local type_ip = type(ip) + local constructor = constructors[type_ip] + if not constructor then + return nil, 'invalid ip argument' + end + local iir, ourmask = constructor(ip) + if not iir then + return nil, ourmask + end + local outmask, err = decide_mask(ourmask, mask, high) + if not outmask then return nil, err end + + return iir, outmask +end + +local function new_inet4(ip, mask) + local bip, outmask = generic_new(inet4_constructors, 32, ip, mask) + if not bip then return nil, outmask end return setmetatable({ bip = bip, - mask = ourmask, + mask = outmask, }, inet4) end +local function new_inet6(ip, mask) + local pcs, outmask = generic_new(inet6_constructors, 128, ip, mask) + if not pcs then return nil, outmask end + + local r = setmetatable({ + pcs = pcs, + mask = outmask, + }, inet6) + + -- ensure that the result is balanced + if not r:is_balanced() then + r:balance() + return nil, tostring(r)..' unbalanced' + end + + return r +end + +local function new_inet(ip, mask) + local is_ipv6 + local type_ip = type(ip) + if type_ip == 'string' then + is_ipv6 = string.find(ip, ':', 1, true) + elseif type_ip == 'number' then + is_ipv6 = false + elseif is_inet4(ip) then + mask = mask or #ip + ip = ip.bip + is_ipv6 = false + elseif is_inet6(ip) then + mask = mask or #ip + ip = ip.pcs + is_ipv6 = true + elseif type_ip == 'table' then + local n = #ip + if n == 8 then + is_ipv6 = true + elseif n == 4 then + is_ipv6 = false + else + return nil, 'invalid table' + end + else + return nil, 'invalid ip type' + end + + if is_ipv6 then + return new_inet6(ip, mask) + else + return new_inet4(ip, mask) + end +end + local function tostr4(self, withmask) -- return human readable local bip, mask = self.bip, self.mask @@ -172,24 +301,24 @@ function inet4:cidrstring() end function inet4:__add(n) - return inet4.new(self.bip + n, self.mask) + return new_inet4(self.bip + n, self.mask) end function inet4:__sub(n) - return inet4.new(self.bip - n, self.mask) + return new_inet4(self.bip - n, self.mask) end function inet4:__mul(n) local new = self.bip + (n * math.pow(2, 32 - self.mask)) - return inet4.new(new, self.mask) + return new_inet4(new, self.mask) end function inet4:__div(n) - return inet4.new(self.bip, n) + return new_inet4(self.bip, n) end function inet4:__pow(n) - return inet4.new(self.bip, self.mask + n) + return new_inet4(self.bip, self.mask + n) end function inet4:__lt(other) @@ -222,12 +351,12 @@ end function inet4:network() local hostbits = 32 - self.mask - return inet4.new(lshift(rshift(self.bip, hostbits), hostbits), self.mask) + return new_inet4(lshift(rshift(self.bip, hostbits), hostbits), self.mask) end function inet4:netmask() local hostbits = 32 - self.mask - return inet4.new(replace(0xffffffff, 0, 0, hostbits), 32) + return new_inet4(replace(0xffffffff, 0, 0, hostbits), 32) end function inet4:flip() @@ -236,66 +365,7 @@ function inet4:flip() if mask == 0 then return nil end local hostbits = 32 - mask local flipbit = lshift(1, hostbits) - return inet4.new(bxor(self.bip, flipbit), mask) -end - -local function parse6(ipstr) - local pcs, netmask = ipv6_parser:match(ipstr) - if not pcs then return nil end - if #pcs > 8 then return nil, 'too many pieces' end - local zero_pieces = 8 - #pcs - for i=1,#pcs do - if pcs[i] == '::' then - pcs[i] = 0 - for j=1,#pcs-i do - pcs[i+j+zero_pieces] = pcs[i+j] - end - for j=1,zero_pieces do - pcs[i+j] = 0 - end - end - end - if #pcs > 8 then return nil, 'too many pieces' end - if netmask == '' then - netmask = 128 - elseif netmask > 128 then - return nil, 'invalid netmask' - end - return pcs, netmask -end - -function inet6.new(ip, netmask) - local pcs, err - if type(ip) == 'string' then - pcs, err = parse6(ip) - if pcs == nil then - return nil, err - end - if not netmask then - netmask = err - end - elseif type(ip) == 'table' then - pcs = { ip[1], ip[2], ip[3], ip[4], - ip[5], ip[6], ip[7], ip[8] } - if not netmask then - netmask = 128 - end - else - return nil - end - - local r = setmetatable({ - pcs = pcs, - mask = netmask, - }, inet6) - - -- ensure that the result is balanced - if not r:is_balanced() then - r:balance() - return nil, tostring(r)..' unbalanced' - end - - return r + return new_inet4(bxor(self.bip, flipbit), mask) end -- each ipv6 address is stored as eight pieces @@ -415,7 +485,7 @@ function inet6:cidrstring() end function inet6:clone() - return inet6.new(self.pcs, self.mask) + return new_inet6(self.pcs, self.mask) end function inet6:__eq(other) @@ -433,11 +503,11 @@ function inet6:__eq(other) end function inet6:__div(n) - return inet6.new(self.pcs, n) + return new_inet6(self.pcs, n) end function inet6:__pow(n) - return inet6.new(self.pcs, self.mask + n) + return new_inet6(self.pcs, self.mask + n) end function inet6:__add(n) @@ -466,7 +536,7 @@ function inet6:network() newpcs[i] = lshift(rshift(pcs[i], netbitsleft), netbitsleft) end end - return inet6.new(newpcs, netbits) + return new_inet6(newpcs, netbits) end function inet6:flip() -- cgit v1.2.1