local common = require 'inet.common'
local bitops = require 'inet.bitops'

local format = string.format
local floor = math.floor
local min = math.min
local max = math.max
local pow = math.pow
local insert = table.insert

local lshift = bitops.lshift
local rshift = bitops.rshift
local band = bitops.band
local extract = bitops.extract
local replace = bitops.replace
local bxor = bitops.bxor

-- math.pow does not exist in newer Lua versions; fall back to the ^ operator.
if not pow then
	pow = function(x, y) return x ^ y end
end

local get_mt = common.get_mt

-- Set consulted by the IPv6 code (tostr6, inet6 arithmetic) to decide
-- whether an address uses mixed IPv4-embedded notation; it is installed
-- through M.set_mixed_networks.
local mixed_networks

-- Maps each family metatable to its address family number (4 or 6).
local mt2fam = {}

local inet = {}
inet.__index = inet

local inet4 = setmetatable({}, inet)
inet4.__index = inet4
mt2fam[inet4] = 4

local inet6 = setmetatable({}, inet)
inet6.__index = inet6
mt2fam[inet6] = 6

local function is_inet4(t)
	local mt = get_mt(t)
	return mt == inet4
end

local function is_inet6(t)
	local mt = get_mt(t)
	return mt == inet6
end

local function is_inet(t)
	local mt = get_mt(t)
	return mt == inet4 or mt == inet6
end

--- Length operator: `#addr` is the prefix length (mask) of the address.
function inet:__len()
	local mask = self.mask
	if mask == nil then
		-- the metatables themselves have no mask; keep them inspectable
		return 0
	end
	return mask
end
inet4.__len = inet.__len
inet6.__len = inet.__len

--- Number of subnets of prefix length `n` inside this network.
-- @tparam number n target prefix length (0..32 for IPv4, 0..128 for IPv6)
-- @return 2^(n - mask), or nil and an error message
-- NOTE(review): n < self.mask yields a fraction rather than an error;
-- confirm whether callers rely on that.
function inet:subnets(n)
	if type(n) ~= 'number' then
		return nil, 'n must be a number'
	end
	local hostmask = is_inet6(self) and 128 or 32
	if n < 0 or n > hostmask then
		return nil, 'invalid mask given'
	end
	local mask = self.mask
	local bits = n - mask
	local subnets = pow(2, bits)
	return subnets
end

--- Address family of this object: 4 or 6.
function inet:family()
	local mt = assert(getmetatable(self))
	return assert(mt2fam[mt])
end

local ipv4_parser
local ipv6_parser
do
	local lpeg = require 'lpeg'
	local C, Ct = lpeg.C, lpeg.Ct
	local S, R = lpeg.S, lpeg.R
	local B, Cc = lpeg.B, lpeg.Cc

	local digit = R('09')

	local ipv4addr
	do
		local dot = S('.')
		local zero = S('0')
		-- matches when the preceding character was a zero, i.e. the octet
		-- consisted only of zeros (all consumed by zero^0 below)
		local octet0 = B(zero) * Cc('0')
		local octet1 = R('19')                       -- 1-9
		local octet2 = R('19') * digit               -- 10-99
		local octet31 = S('1') * digit * digit       -- 100-199
		local octet32 = R('04') * digit              -- 20x-24x (after '2')
		local octet325 = S('5') * R('05')            -- 250-255 (after '2')
		local octet3 = octet31 + (S('2') * (octet32 + octet325))
		-- leading zeros are skipped; the longest alternative is tried first
		local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) / tonumber
		ipv4addr = octet * dot * octet * dot * octet * dot * octet

		local mask12 = R('12') * digit               -- 10-29
		local mask3 = S('3') * R('02')               -- 30-32
		local netmask = S('/') * C(mask12 + mask3 + digit)
		ipv4_parser = ipv4addr * (netmask + Cc()) * -1
	end

	do
		local function hextonumber(hex)
			return tonumber(hex, 16)
		end

		local hexdigit = R('09') + R('af') + R('AF')
		-- one 16-bit group: 1 to 4 hex digits
		local piece = C(hexdigit * (hexdigit^-3)) / hextonumber
		local col = S(':')
		-- '::' is captured literally and expanded by inet6_from_string
		local colcol = C(col * col)
		local picol = piece * col
		local colpi = col * piece
		-- a dotted-quad suffix becomes the last two 16-bit groups
		local ipv4embed = ipv4addr / function(a, b, c, d)
			return lshift(a, 8) + b, lshift(c, 8) + d
		end
		local last32bits = (ipv4embed + (picol * piece))
		local full = picol * picol * picol * picol * picol * picol * last32bits
		local partial = (piece * (colpi^-6))^-1 * colcol
			* ((picol^-6) * (ipv4embed + piece))^-1
		-- require at least one digit: a bare trailing '/' must not parse
		-- (previously digit^-3 matched nothing and the mask silently became 128)
		local netmask = S('/') * C(digit * digit^-2) / tonumber
		local pieces = full + partial
		ipv6_parser = Ct(pieces) * ((netmask + Cc())^-1) * -1
	end
end

-- Combine four octets (most significant first) into one number.
local function build_bip(o1, o2, o3, o4)
	return lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
end

--- Parse a dotted-quad string, optionally followed by "/mask".
-- @return numeric address and mask (or nil), or nil and 'parse error'
local function inet4_from_string(ipstr)
	local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr)
	if not o1 then
		return nil, 'parse error'
	end
	local bip = build_bip(o1, o2, o3, o4)
	return bip, tonumber(mask)
end

local function inet4_from_number(bip)
	if bip < 0 or bip > 0xffffffff then
		return nil, 'out of range'
	end
	return bip
end

--- Build a numeric address from a table of four octets.
local function inet4_from_table(t)
	if #t ~= 4 then
		return nil, 'invalid length'
	end
	for i = 1, 4 do
		local v = t[i]
		if type(v) ~= 'number' then
			return nil, 'invalid number'
		end
		if v < 0 or v > 255 then
			return nil, 'octet out of range'
		end
	end
	return build_bip(t[1], t[2], t[3], t[4])
end

-- Dispatch on type(ip) when constructing an IPv4 object.
local inet4_constructors = {
	string = inet4_from_string,
	number = inet4_from_number,
	table = inet4_from_table,
}

--- Parse an IPv6 string into a table of eight 16-bit pieces.
-- @return pieces table and mask (or nil), or nil and an error message
local function inet6_from_string(ipstr)
	local pcs, netmask = ipv6_parser:match(ipstr)
	if not pcs then
		return nil, 'parse error'
	end
	if #pcs > 8 then
		return nil, 'too many pieces'
	end

	-- expand the '::' marker into as many zero pieces as are missing
	local zero_pieces = 8 - #pcs
	for i = 1, #pcs do
		if pcs[i] == '::' then
			pcs[i] = 0
			-- shift the pieces after the marker to the right ...
			for j = #pcs, i, -1 do
				pcs[j + zero_pieces] = pcs[j]
			end
			-- ... and fill the gap with zeros
			for j = 1, zero_pieces do
				pcs[i + j] = 0
			end
		end
	end

	if #pcs > 8 then
		return nil, 'too many pieces'
	end
	if netmask ~= nil and netmask > 128 then
		return nil, 'invalid netmask'
	end
	return pcs, netmask
end

--- Validate and copy a table of eight 16-bit pieces.
local function inet6_from_table(t)
	if #t ~= 8 then
		return nil, 'invalid length'
	end
	for i = 1, 8 do
		local v = t[i]
		if type(v) ~= 'number' then
			return nil, 'invalid number'
		end
		if v < 0 or v > 0xffff then
			return nil, 'piece out of range'
		end
	end
	return { t[1], t[2], t[3], t[4], t[5], t[6], t[7], t[8] }
end

-- Dispatch on type(ip) when constructing an IPv6 object.
local inet6_constructors = {
	string = inet6_from_string,
	table = inet6_from_table,
}

--- Choose the final prefix length.
-- @param from_ip mask parsed from the address itself (or nil)
-- @param override mask passed explicitly by the caller (or nil)
-- @param high maximum prefix length (32 or 128), also the default
-- @return the mask, or nil and an error message
local function decide_mask(from_ip, override, high)
	local newmask = from_ip
	if override then
		-- only whole numbers are meaningful prefix lengths
		if type(override) == 'number' and override >= 0 and override <= high
				and floor(override) == override then
			if from_ip ~= nil then
				return nil, 'multiple masks supplied'
			end
			newmask = override
		else
			return nil, 'invalid mask'
		end
	elseif not newmask then
		newmask = high
	end
	return newmask
end

--- Shared constructor logic: pick a constructor by type and resolve the mask.
-- @return internal representation and mask, or nil and an error message
local function generic_new(constructors, high, ip, mask)
	local constructor = constructors[type(ip)]
	if not constructor then
		return nil, 'invalid ip argument'
	end
	local iir, ourmask = constructor(ip)
	if not iir then
		return nil, ourmask
	end
	local outmask, err = decide_mask(ourmask, mask, high)
	if not outmask then
		return nil, err
	end
	return iir, outmask
end

local function new_inet4(ip, mask)
	local bip, outmask = generic_new(inet4_constructors, 32, ip, mask)
	if not bip then
		return nil, outmask
	end
	return setmetatable({
		bip = bip,
		mask = outmask,
	}, inet4)
end

local function new_inet6(ip, mask)
	local pcs, outmask = generic_new(inet6_constructors, 128, ip, mask)
	if not pcs then
		return nil, outmask
	end
	local r = setmetatable({
		pcs = pcs,
		mask = outmask,
	}, inet6)
	-- ensure that the result is balanced
	if not r:is_balanced() then
		r:balance()
		return nil, tostring(r) .. ' unbalanced'
	end
	return r
end

--- Create an IPv4 or IPv6 object.
-- @param ip string, number (IPv4), 4- or 8-element table, or an existing
--   inet object (its mask is reused unless `mask` is given)
-- @param mask optional prefix length
-- @return inet object, or nil and an error message
local function new_inet(ip, mask)
	local is_ipv6
	local type_ip = type(ip)
	if type_ip == 'string' then
		is_ipv6 = string.find(ip, ':', 1, true)
	elseif type_ip == 'number' then
		is_ipv6 = false
	elseif is_inet4(ip) then
		mask = mask or #ip
		ip = ip.bip
		is_ipv6 = false
	elseif is_inet6(ip) then
		mask = mask or #ip
		ip = ip.pcs
		is_ipv6 = true
	elseif type_ip == 'table' then
		local n = #ip
		if n == 8 then
			is_ipv6 = true
		elseif n == 4 then
			is_ipv6 = false
		else
			return nil, 'invalid table'
		end
	else
		return nil, 'invalid ip type'
	end
	if is_ipv6 then
		return new_inet6(ip, mask)
	else
		return new_inet4(ip, mask)
	end
end

--- Human-readable IPv4 string.
-- @param withmask true forces "/mask", false suppresses it; nil shows the
--   mask only when it is not 32
local function tostr4(self, withmask)
	local bip, mask = self.bip, self.mask
	local o1 = band(rshift(bip, 24), 0xff)
	local o2 = band(rshift(bip, 16), 0xff)
	local o3 = band(rshift(bip, 8), 0xff)
	local o4 = band(bip, 0xff)
	if (mask == nil or mask == 32 or withmask == false) and withmask ~= true then
		return string.format('%d.%d.%d.%d', o1, o2, o3, o4)
	else
		return string.format('%d.%d.%d.%d/%d', o1, o2, o3, o4, mask)
	end
end

function inet4:__tostring()
	return tostr4(self)
end

function inet4:ipstring()
	return tostr4(self, false)
end

function inet4:cidrstring()
	return tostr4(self, true)
end

--- inet4 + number offsets the address; inet4 + inet6 delegates to inet6.__add.
function inet4:__add(n)
	local type_n = type(n)
	if type_n == 'number' then
		return new_inet4(self.bip + n, self.mask)
	elseif type_n == 'table' and is_inet6(n) then
		return inet6.__add(n, self)
	else
		return nil, 'invalid argument'
	end
end

--- inet4 - number offsets the address; inet4 - inet4 gives the numeric distance.
function inet4:__sub(n)
	local type_n = type(n)
	if type_n == 'number' then
		return new_inet4(self.bip - n, self.mask)
	elseif type_n == 'table' and is_inet4(n) then
		return self.bip - n.bip
	else
		return nil, 'invalid argument'
	end
end

--- Step `n` whole networks (of the current prefix length) forward.
function inet4:__mul(n)
	local new = self.bip + (n * pow(2, 32 - self.mask))
	return new_inet4(new, self.mask)
end

--- Same address with prefix length `n`.
function inet4:__div(n)
	return new_inet4(self.bip, n)
end

--- Same address with the prefix length increased by `n`.
function inet4:__pow(n)
	return new_inet4(self.bip, self.mask + n)
end

function inet4:clone()
	return new_inet4(self.bip, self.mask)
end

--- True when `other` lies inside this network.
-- NOTE(review): this is strict (equal masks give false), whereas
-- inet6:contains accepts equal masks; confirm which is intended.
function inet4:contains(other)
	if self.mask >= other.mask then
		return false
	end
	local mask = self.mask
	-- clear the host bits of both addresses and compare
	local self_netbits = replace(self.bip, 0, 0, 32 - mask)
	local other_netbits = replace(other.bip, 0, 0, 32 - mask)
	return self_netbits == other_netbits
end

--- Ordering by address, then by mask.
function inet4:__lt(other)
	if self.bip == other.bip then
		return self.mask < other.mask
	end
	return self.bip < other.bip
end

--- self <= other when self lies within (or equals) the network other.
function inet4:__le(other)
	if self.mask < other.mask then
		return false
	end
	local mask = other.mask
	if mask == 32 then
		-- no host bits to clear
		return self.bip == other.bip
	else
		local selfnet = replace(self.bip, 0, 0, 32 - mask)
		local othernet = replace(other.bip, 0, 0, 32 - mask)
		return selfnet == othernet
	end
end

function inet4:__eq(other)
	return self.bip == other.bip and self.mask == other.mask
end

--- Network address: host bits cleared, same mask.
function inet4:network()
	local hostbits = 32 - self.mask
	return new_inet4(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
end

--- Netmask as a /32 address (e.g. 255.255.255.0 for a /24).
function inet4:netmask()
	local hostbits = 32 - self.mask
	return new_inet4(replace(0xffffffff, 0, 0, hostbits), 32)
end

--- Hostmask as a /32 address (e.g. 0.0.0.255 for a /24).
function inet4:hostmask()
	local mask = self.mask
	local hostbits = 32 - mask
	return new_inet4(replace(0xffffffff, 0, hostbits, mask), 32)
end

--- Sibling network, found by flipping the last network bit (nil for /0).
function inet4:flip()
	local mask = self.mask
	if mask == 0 then
		return nil
	end
	local hostbits = 32 - mask
	local flipbit = lshift(1, hostbits)
	return new_inet4(bxor(self.bip, flipbit), mask)
end

--- Split the address into n-bit chunks, most significant first.
-- @tparam number n chunk width; must divide 32
function inet4:bits(n)
	if type(n) ~= 'number' then
		return nil, 'n must be a number'
	end
	if n < 1 or n > 32 or 32 % n ~= 0 then
		return nil, 'invalid value for n'
	end
	local t = {}
	local bip = self.bip
	for i = 32 - n, 0, -n do
		insert(t, extract(bip, i, n))
	end
	return t
end

-- each ipv6 address is stored as eight pieces
-- 1111:2222:3333:4444:5555:6666:7777:8888
-- in the table pcs.
--- True when every piece of the address lies within 0..0xffff.
function inet6:is_balanced()
	local pcs = self.pcs
	for i = 1, 8 do
		local piece = pcs[i]
		if piece < 0 or piece > 0xffff then
			return false
		end
	end
	return true
end

--- Normalise a pieces table in place so that every piece fits in 16 bits.
-- First pass: negative pieces borrow 0x10000 from the piece to their left.
-- Second pass: bits above 16 are carried into the piece to their left.
-- @param pcs table of eight numbers (modified in place)
-- @param quick when truthy, the borrow pass stops at the first strictly
--   positive piece and the carry pass at the first piece with nothing to carry
-- @return true, or nil and 'out of range' if the first piece ends up
--   outside 0..0xffff
local function do_balance(pcs, quick)
	local i = 8
	while i > 1 do
		if quick and pcs[i] > 0 then break end
		while pcs[i] < 0 do
			pcs[i] = pcs[i] + 0x10000
			pcs[i - 1] = pcs[i - 1] - 1
		end
		i = i - 1
	end
	i = 8
	while i > 1 do
		local extra = rshift(pcs[i], 16)
		if quick and extra == 0 then break end
		pcs[i] = band(pcs[i], 0xffff)
		pcs[i - 1] = pcs[i - 1] + extra
		i = i - 1
	end
	if pcs[1] < 0 or pcs[1] > 0xffff then
		return nil, 'out of range'
	end
	return true
end

--- Balance this address in place (see do_balance).
-- @return self, or nil and an error message
function inet6:balance(quick)
	local ok, err = do_balance(self.pcs, quick)
	if not ok then
		return nil, err
	end
	return self
end

local function tohex(n)
	if n == nil then
		return nil
	end
	return format('%x', n)
end

--- Human-readable IPv6 string.
-- @param withmask true forces "/mask", false suppresses it; nil shows the
--   mask only when it is not 128
-- @param embeddedipv4 true/false forces or suppresses a dotted-quad
--   suffix; nil asks mixed_networks
local function tostr6(self, withmask, embeddedipv4)
	local pcs = self.pcs
	local zeros = {}

	if embeddedipv4 == nil then
		embeddedipv4 = mixed_networks:contains(self)
	end

	-- with an IPv4 suffix only the first six pieces are written in hex
	local ipv6pieces = 8
	if embeddedipv4 then
		ipv6pieces = 6
	end

	-- count zero clusters: zeros[start] = length of the run starting there
	local first_zero = 0
	local prev_was_zero = false
	for i = 1, ipv6pieces do
		if pcs[i] == 0 then
			if prev_was_zero then
				zeros[first_zero] = zeros[first_zero] + 1
			else
				first_zero = i
				zeros[first_zero] = 1
			end
			prev_was_zero = true
		else
			prev_was_zero = false
		end
	end

	-- find the first largest zero cluster; since zeros_cnt starts at 1,
	-- a single zero piece is never compressed to '::'
	local zeros_begin = nil
	local zeros_cnt = 1
	for begin = 1, ipv6pieces do
		local cnt = zeros[begin] or 0
		if cnt > zeros_cnt then
			zeros_begin = begin
			zeros_cnt = cnt
		end
	end

	-- format ipv6 address
	local out = ''
	local i = 1
	while i <= ipv6pieces do
		if i == zeros_begin then
			-- the previous piece already emitted one ':'
			if i > 1 then
				out = out .. ':'
			else
				out = out .. '::'
			end
			i = i + zeros_cnt
		else
			out = out .. tohex(pcs[i])
			if i ~= 8 then
				out = out .. ':'
			end
			i = i + 1
		end
	end

	if embeddedipv4 then
		out = out .. new_inet4(lshift(pcs[7], 16) + pcs[8]):ipstring()
	end

	local mask = self.mask
	if (mask == nil or mask == 128 or withmask == false) and withmask ~= true then
		return out
	else
		return string.format('%s/%d', out, mask)
	end
end

function inet6:__tostring()
	return tostr6(self)
end

function inet6:ipstring()
	return tostr6(self, false)
end

function inet6:ipstring4()
	return tostr6(self, false, true)
end

function inet6:ipstring6()
	return tostr6(self, false, false)
end

function inet6:cidrstring()
	return tostr6(self, true)
end

function inet6:clone()
	return new_inet6(self.pcs, self.mask)
end

--- True when `other` lies inside (or is equal to) this network.
-- @return boolean, or nil and an error message if `other` cannot be
--   re-masked to this prefix length
function inet6:contains(other)
	local mask = self.mask
	if mask > other.mask then
		return false
	end
	-- truncate other to our prefix length, then compare network addresses
	local widened, err = other / mask
	if not widened then
		return nil, err
	end
	return self:network() == widened:network()
end

--- Ordering: lexicographic over the pieces, then by mask.
function inet6:__lt(other)
	local spcs = self.pcs
	local opcs = other.pcs
	for i = 1, 8 do
		if spcs[i] < opcs[i] then return true end
		if spcs[i] > opcs[i] then return false end
	end
	return self.mask < other.mask
end

function inet6:__le(other)
	local spcs = self.pcs
	local opcs = other.pcs
	for i = 1, 8 do
		if spcs[i] < opcs[i] then return true end
		if spcs[i] > opcs[i] then return false end
	end
	return self.mask <= other.mask
end

function inet6:__eq(other)
	if self.mask ~= other.mask then
		return false
	end
	local spcs = self.pcs
	local opcs = other.pcs
	for i = 1, 8 do
		if spcs[i] ~= opcs[i] then
			return false
		end
	end
	return true
end

--- Same address with prefix length `n`.
function inet6:__div(n)
	return new_inet6(self.pcs, n)
end

--- Same address with the prefix length increased by `n`.
function inet6:__pow(n)
	return new_inet6(self.pcs, self.mask + n)
end

--- inet6 + number offsets the address.
-- inet6 (/96 mixed network address) + inet4 (/32) embeds the IPv4
-- address into the last 32 bits and yields a /128.
function inet6:__add(n)
	local new = self:clone()
	local pcs = new.pcs
	local type_n = type(n)
	if type_n == 'number' then
		pcs[8] = pcs[8] + n
	elseif type_n == 'table' and is_inet4(n) then
		if #new ~= 96 then
			return nil, 'inet6 must be a /96'
		end
		if #n ~= 32 then
			return nil, 'inet4 must be a /32'
		end
		if not mixed_networks:contains(new) then
			return nil, 'inet6 is not a mixed notation network'
		end
		if new ~= new:network() then
			return nil, 'inet6 must be a network address'
		end
		local bip = n.bip
		pcs[7] = band(rshift(bip, 16), 0xffff)
		pcs[8] = band(bip, 0xffff)
		new.mask = 128
	else
		return nil, 'invalid argument'
	end
	return new:balance(true)
end

--- inet6 - number offsets the address backwards.
-- inet6 - inet6 returns the distance as a number (it must fit in 32 bits),
-- or an inet4 when self is a /128 and n a /96, both in mixed_networks.
function inet6:__sub(n)
	local type_n = type(n)
	if type_n == 'number' then
		return self + (n * -1)
	elseif type_n == 'table' and is_inet6(n) then
		local spcs = self.pcs
		local npcs = n.pcs
		local dpcs = {}
		for i = 1, 8 do
			dpcs[i] = spcs[i] - npcs[i]
		end
		local ok, err = do_balance(dpcs)
		if not ok then
			return nil, err, dpcs
		end
		local ret = 0
		for i = 1, 8 do
			local v = dpcs[i]
			-- only the last two pieces may be non-zero
			if (i < 7 and v > 0) or v < 0 or v > 0xffff then
				return nil, 'result is out of range', dpcs
			end
			local bits = (8 - i) * 16
			ret = ret + lshift(band(v, 0xffff), bits)
		end
		if #self == 128 and #n == 96
				and mixed_networks:contains(self)
				and mixed_networks:contains(n) then
			return new_inet4(ret)
		else
			return ret
		end
	else
		return nil, 'invalid argument'
	end
end

--- Network address: host bits cleared, same mask.
function inet6:network()
	local netbits = self.mask
	local pcs = self.pcs
	local newpcs = { 0, 0, 0, 0, 0, 0, 0, 0 }
	for i = 1, 8 do
		if netbits >= i * 16 then
			newpcs[i] = pcs[i]
		elseif netbits <= (i - 1) * 16 then
			break -- the rest is already zero
		else
			-- partially masked piece: clear its low host bits
			local netbitsleft = 16 - (netbits - ((i - 1) * 16))
			newpcs[i] = lshift(rshift(pcs[i], netbitsleft), netbitsleft)
		end
	end
	return new_inet6(newpcs, netbits)
end

--- Build a /128 address made of z1 zero bits, o1 one bits, then z2 zero bits
-- (counted from the most significant bit).
local function build_inet6_mask(z1, o1, z2)
	assert(z1 + o1 + z2 == 128)
	local pcs = { 0, 0, 0, 0, 0, 0, 0, 0 }
	local b, l = z1, o1
	if l > 0 then
		local e = b + l - 1 -- index of the last one bit
		local bpcs = floor(b / 16) + 1
		local epcs = floor(e / 16) + 1
		for j = bpcs, epcs do
			local o = (j - 1) * 16
			-- offsets within this piece, counted from its MSB
			local bo = max(0, b - o)
			local width = min(15, e - o) + 1 - bo
			-- replace() counts bit positions from the LSB
			local fbit = 16 - width - bo
			pcs[j] = replace(pcs[j], 0xffff, fbit, width)
		end
	end
	return new_inet6(pcs)
end

function inet6:netmask()
	local mask = self.mask
	return build_inet6_mask(0, mask, 128 - mask)
end

function inet6:hostmask()
	local mask = self.mask
	return build_inet6_mask(mask, 128 - mask, 0)
end

--- Sibling network, found by flipping the last network bit (nil for /0).
function inet6:flip()
	local mask = self.mask
	if mask == 0 then
		return nil
	end
	local block = rshift(mask, 4) + 1
	local maskbits = band(mask, 0xf)
	local bitno = 16 - maskbits
	if bitno == 16 then
		-- mask ends on a piece boundary: flip the LSB of the previous piece
		block = block - 1
		bitno = 0
	end
	local flipbit = lshift(1, bitno)
	local r = self:clone()
	r.pcs[block] = bxor(r.pcs[block], flipbit)
	return r
end

--- Step `n` whole networks (of the current prefix length) forward
-- (backward for negative n).
-- @return new address, or nil and an error message
function inet6:__mul(n)
	local new = self:clone()
	local mask = new.mask
	if mask == 0 then
		return nil, 'unable to perform operation'
	end
	local pcs = new.pcs

	-- p is the piece holding the last network bit; netbitremainder is that
	-- bit's position counted from the piece's LSB
	local netbitoverflow = mask % 16
	local netbitremainder = (128 - mask) % 16
	local p = (mask - netbitoverflow) / 16
	if netbitremainder ~= 0 then
		p = p + 1
	end

	local was_negative = false
	if n < 0 then
		n = n * -1
		was_negative = true
	end

	local shiftet = lshift(n, netbitremainder)
	local high_shift = rshift(shiftet, 16)
	local low_shift = band(shiftet, 0xffff)

	if was_negative then
		high_shift = -high_shift
		low_shift = -low_shift
	end

	-- previously `p > 2`, which silently dropped the carry into piece 1
	if p > 1 then
		pcs[p - 1] = pcs[p - 1] + high_shift
	elseif high_shift ~= 0 then
		-- the step would move past the first piece
		return nil, 'out of range'
	end
	pcs[p] = pcs[p] + low_shift
	return new:balance()
end

--- Split the address into n-bit chunks, most significant first.
-- @tparam number n chunk width (1..32); must divide 128
function inet6:bits(n)
	if type(n) ~= 'number' then
		return nil, 'n must be a number'
	end
	if n < 1 or n > 32 or 128 % n ~= 0 then
		return nil, 'invalid value for n'
	end
	local t = {}
	local pcs = self.pcs
	if n == 32 then
		-- a 32-bit chunk spans two pieces
		for i = 1, 8, 2 do
			insert(t, lshift(pcs[i], 16) + pcs[i + 1])
		end
	else
		for i = 1, 8 do
			local p = pcs[i]
			for j = 16 - n, 0, -n do
				insert(t, extract(p, j, n))
			end
		end
	end
	return t
end

local M = {}

--- Install the set used to decide mixed (IPv4-embedded) notation.
function M.set_mixed_networks(mixed_set)
	mixed_networks = mixed_set
end

M.is_inet4 = is_inet4
M.is_inet6 = is_inet6
M.is_inet = is_inet
M.new_inet = new_inet

return M