-- IPv4 / IPv6 address objects with operator support.
--
-- Operator summary (see the metamethods below):
--   addr / 24    same address with the mask replaced by 24
--   addr ^ n     same address with the mask increased by n
--   addr + n     address n steps further up (addr - n: further down)
--   net  * n     the network n networks of the same size further up,
--                e.g. (ipv6/56) * 5 is the /56 five networks further up
--   #addr        the mask length

local bit32 = require 'bit32'

local inet, inet4, inet6
local mt2fam = {}

-- inet is the shared base; inet4 and inet6 inherit from it and double as
-- the metatables of their instances.
inet = {}
inet.__index = inet
inet4 = setmetatable({}, inet)
inet4.__index = inet4
mt2fam[inet4] = 4
inet6 = setmetatable({}, inet)
inet6.__index = inet6
mt2fam[inet6] = 6

--- Return the metatable of t, or nil when t is not a table.
local function get_mt(t)
	if type(t) ~= 'table' then return nil end
	return getmetatable(t)
end

--- True when t is an inet4 instance.
local function is_inet4(t)
	local mt = get_mt(t)
	return mt == inet4
end

--- True when t is an inet6 instance.
local function is_inet6(t)
	local mt = get_mt(t)
	return mt == inet6
end

--- True when t is an inet4 or inet6 instance.
local function is_inet(t)
	local mt = get_mt(t)
	return mt == inet4 or mt == inet6
end

--- Create an address object, picking the family from the argument.
-- Tables (eight 16-bit pieces) and strings containing ':' are handed to
-- inet6.new; everything else is handed to inet4.new.
-- @param ip   string, number (IPv4) or table of pieces (IPv6)
-- @param mask optional mask length overriding any '/n' suffix in ip
-- @return the new object, or nil (plus an error message for IPv6) on failure
function inet.new(ip, mask)
	if type(ip) == 'table' then
		-- string.find would raise on a table; inet6.new accepts piece tables.
		return inet6.new(ip, mask)
	end
	if type(ip) == 'string' and string.find(ip, ':', 1, true) then
		return inet6.new(ip, mask)
	end
	return inet4.new(ip, mask)
end

--- Length operator: the mask length, or 0 when no mask is set.
function inet:__len()
	local mask = self.mask
	if mask == nil then return 0 end -- make metatable inspectable
	return mask
end

-- Metamethods are looked up raw on the metatable, so they are not inherited
-- through __index and have to be copied explicitly.
inet4.__len = inet.__len
inet6.__len = inet.__len

--- Return the address family of the object: 4 or 6.
function inet:family()
	local mt = assert(getmetatable(self))
	return assert(mt2fam[mt])
end

local lshift = bit32.lshift
local rshift = bit32.rshift
local band = bit32.band
local replace = bit32.replace
local bxor = bit32.bxor

local ipv4_parser
local ipv6_parser
do
	local lpeg = require 'lpeg'
	local C, Ct = lpeg.C, lpeg.Ct
	local S, R = lpeg.S, lpeg.R
	local B, Cc = lpeg.B, lpeg.Cc

	local digit = R('09')

	-- IPv4 grammar: four octets (leading zeros are skipped), optional '/mask'.
	-- Captures: four octet strings followed by the mask string ('' if absent).
	do
		local dot = S('.')
		local zero = S('0')
		-- an octet consisting only of zeros: yield the constant '0'
		local octet0 = B(zero) * Cc('0')
		local octet1 = R('19')
		local octet2 = R('19') * digit
		local octet31 = S('1') * digit * digit
		local octet32 = R('04') * digit
		local octet325 = S('5') * R('05')
		local octet3 = octet31 + (S('2') * (octet32 + octet325))
		local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0)
		local ipv4 = octet * dot * octet * dot * octet * dot * octet
		local mask12 = R('12') * digit
		local mask3 = S('3') * R('02')
		local netmask = S('/') * C(mask12 + mask3 + digit)
		ipv4_parser = ipv4 * (netmask + C('')) * -1
	end

	-- IPv6 grammar: either eight full pieces or a form containing one '::'.
	-- Captures: a table of pieces (numbers, with the string '::' marking the
	-- compressed run) followed by the mask (number, or '' if absent).
	do
		local function hextonumber(hex) return tonumber(hex, 16) end
		local hexdigit = R('09') + R('af') + R('AF')
		local piece = C(hexdigit * (hexdigit^-3)) / hextonumber
		local col = S(':')
		local colcol = C(col * col)
		local picol = piece * col
		local colpi = col * piece
		local full = picol * picol * picol * picol * picol * picol * picol * piece
		local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*piece)^-1
		local netmask = S('/') * C((digit^-3)) / tonumber
		ipv6_parser = Ct(full + partial) * ((netmask + C(''))^-1) * -1
	end
end

--- Parse a dotted-quad string with optional '/mask'.
-- @return the address as a 32-bit number and the mask (nil when absent),
--         or nil when the string does not match
local function parse4(ipstr)
	local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr)
	if o1 == nil then return nil end
	local bip = lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4
	return bip, tonumber(mask)
end

--- Create an IPv4 object.
-- @param ip   dotted-quad string (optionally with '/mask') or a number
-- @param mask optional mask length (0..32); overrides a mask parsed from ip
-- @return the new object, or nil when ip cannot be parsed or is neither a
--         string nor a number
-- Raises 'invalid mask' when mask is given but is not a number in 0..32.
function inet4.new(ip, mask)
	local bip, ourmask
	if type(ip) == 'string' then
		bip, ourmask = parse4(ip)
		if bip == nil then return nil end
	elseif type(ip) == 'number' then
		bip = ip
	else
		-- anything else would produce an object without an address
		return nil
	end
	if mask then
		if type(mask) == 'number' and mask >= 0 and mask <= 32 then
			ourmask = mask
		else
			error('invalid mask')
		end
	elseif not ourmask then
		ourmask = 32
	end
	return setmetatable({
		bip = bip,
		mask = ourmask,
	}, inet4)
end

--- Format an IPv4 object.
-- @param withmask true: always append '/mask'; false: never;
--                 nil: append it unless the mask is nil or 32
local function tostr4(self, withmask)
	-- return human readable
	local bip, mask = self.bip, self.mask
	local o1 = band(rshift(bip, 24), 0xff)
	local o2 = band(rshift(bip, 16), 0xff)
	local o3 = band(rshift(bip, 8), 0xff)
	local o4 = band(bip, 0xff)
	if (mask == nil or mask == 32 or withmask == false) and withmask ~= true then
		return string.format('%d.%d.%d.%d', o1, o2, o3, o4)
	else
		return string.format('%d.%d.%d.%d/%d', o1, o2, o3, o4, mask)
	end
end

function inet4:__tostring()
	return tostr4(self)
end

--- Address without the mask suffix.
function inet4:ipstring()
	return tostr4(self, false)
end

--- Address with the mask suffix.
function inet4:cidrstring()
	return tostr4(self, true)
end

--- addr + n: the address n steps further up, same mask.
function inet4:__add(n)
	return inet4.new(self.bip + n, self.mask)
end

--- addr - n: the address n steps further down, same mask.
function inet4:__sub(n)
	return inet4.new(self.bip - n, self.mask)
end

--- net * n: advance by n networks of this mask's size.
function inet4:__mul(n)
	-- '^' instead of math.pow, which no longer exists in Lua 5.3+
	local new = self.bip + (n * 2 ^ (32 - self.mask))
	return inet4.new(new, self.mask)
end

--- addr / n: same address with the mask replaced by n.
function inet4:__div(n)
	return inet4.new(self.bip, n)
end

--- addr ^ n: same address with the mask increased by n.
function inet4:__pow(n)
	return inet4.new(self.bip, self.mask + n)
end

--- a < b: a has a longer mask than b and agrees with b on b's network bits.
function inet4:__lt(other)
	if self.mask <= other.mask then
		return false
	end
	-- other.mask < self.mask <= 32 here, so the replace width is positive
	local mask = other.mask
	local selfnet = replace(self.bip, 0, 0, 32-mask)
	local othernet = replace(other.bip, 0, 0, 32-mask)
	return selfnet == othernet
end

--- a <= b: like __lt, but equal mask lengths are accepted as well.
function inet4:__le(other)
	if self.mask < other.mask then
		return false
	end
	local mask = other.mask
	if mask == 32 then
		-- bit32.replace rejects a zero width, so compare directly
		return self.bip == other.bip
	else
		local selfnet = replace(self.bip, 0, 0, 32-mask)
		local othernet = replace(other.bip, 0, 0, 32-mask)
		return selfnet == othernet
	end
end

--- Equal when both address and mask match.
function inet4:__eq(other)
	return self.bip == other.bip and self.mask == other.mask
end

--- The address with all host bits cleared, same mask.
function inet4:network()
	local hostbits = 32 - self.mask
	return inet4.new(lshift(rshift(self.bip, hostbits), hostbits), self.mask)
end

--- The netmask as a /32 address (e.g. 255.255.255.0 for a /24).
function inet4:netmask()
	local hostbits = 32 - self.mask
	-- Shift pair instead of bit32.replace: replace raises on a zero width,
	-- which made this method fail for /32 addresses.
	return inet4.new(lshift(rshift(0xffffffff, hostbits), hostbits), 32)
end

--- Find the twin network by flipping the last network bit.
-- @return the twin, or nil for a /0 (which has no network bit to flip)
function inet4:flip()
	-- find twin by flipping the last network bit
	local mask = self.mask
	if mask == 0 then return nil end
	local hostbits = 32 - mask
	local flipbit = lshift(1, hostbits)
	return inet4.new(bxor(self.bip, flipbit), mask)
end

--- Parse an IPv6 string with optional '/mask'.
-- @return a table of eight numeric pieces and the mask (128 when absent),
--         or nil plus an error message (nil alone when nothing matched)
local function parse6(ipstr)
	local pcs, netmask = ipv6_parser:match(ipstr)
	if not pcs then return nil end
	local count = #pcs
	if count > 8 then return nil, 'too many pieces' end
	-- number of zero pieces needed on top of the slot occupied by '::'
	local zero_pieces = 8 - count
	for i = 1, count do
		if pcs[i] == '::' then
			pcs[i] = 0
			-- Move the tail to the end of the address. Source and target
			-- ranges overlap when the tail is longer than the inserted run,
			-- so copy back to front to avoid overwriting unread pieces.
			for j = count - i, 1, -1 do
				pcs[i+j+zero_pieces] = pcs[i+j]
			end
			for j = 1, zero_pieces do
				pcs[i+j] = 0
			end
			break -- the grammar allows only one '::'
		end
	end
	if #pcs > 8 then return nil, 'too many pieces' end
	if netmask == '' then
		netmask = 128
	elseif netmask == nil then
		-- a bare '/' without digits yields no number
		return nil, 'invalid netmask'
	elseif netmask > 128 then
		return nil, 'invalid netmask'
	end
	return pcs, netmask
end

--- Create an IPv6 object.
-- @param ip      string (optionally with '/mask') or table of eight pieces
-- @param netmask optional mask length; defaults to the parsed mask or 128
-- @return the new object, or nil plus an error message (nil alone when ip
--         is neither a string nor a table)
function inet6.new(ip, netmask)
	local pcs, err
	if type(ip) == 'string' then
		pcs, err = parse6(ip)
		if pcs == nil then return nil, err end
		-- on success the second result of parse6 is the parsed mask
		if not netmask then netmask = err end
	elseif type(ip) == 'table' then
		-- copy, so the new object does not share the caller's table
		pcs = { ip[1], ip[2], ip[3], ip[4],
		        ip[5], ip[6], ip[7], ip[8] }
		if not netmask then netmask = 128 end
	else
		return nil
	end

	local r = setmetatable({
		pcs = pcs,
		mask = netmask,
	}, inet6)

	-- ensure that the result is balanced
	if not r:is_balanced() then
		r:balance()
		return nil, tostring(r)..' unbalanced'
	end

	return r
end

-- each ipv6 address is stored as eight pieces
-- 1111:2222:3333:4444:5555:6666:7777:8888
-- in the table pcs.

--- True when every piece is within 0..0xffff.
function inet6:is_balanced()
	local pcs = self.pcs
	for i = 1, 8 do
		local piece = pcs[i]
		if piece < 0 or piece > 0xffff then
			return false
		end
	end
	return true
end

--- Normalize the pieces in place so each fits into 16 bits.
-- First pass borrows from the preceding piece while a piece is negative,
-- second pass carries anything above 16 bits into the preceding piece.
-- The first piece is truncated to 16 bits.
-- @param quick when true, each pass stops at the first piece that needs
--              no adjustment (positive piece / no carry)
-- @return self
function inet6:balance(quick)
	local pcs = self.pcs
	local i = 8
	while i > 1 do
		if quick and pcs[i] > 0 then
			break
		end
		while pcs[i] < 0 do
			pcs[i] = pcs[i] + 0x10000
			pcs[i-1] = pcs[i-1] - 1
		end
		i = i - 1
	end
	i = 8
	while i > 1 do
		local extra = rshift(pcs[i], 16)
		if quick and extra == 0 then
			break
		end
		pcs[i] = band(pcs[i], 0xffff)
		pcs[i-1] = pcs[i-1] + extra
		i = i - 1
	end
	pcs[1] = band(pcs[1], 0xffff)
	return self
end

--- Format an IPv6 object, compressing the first longest run of zero pieces.
-- @param withmask true: always append '/mask'; false: never;
--                 nil: append it unless the mask is nil or 128
local function tostr6(self, withmask)
	-- return human readable
	local pcs = self.pcs

	-- Find the first longest zero run with an ordered scan; iterating a
	-- table of runs with pairs() would resolve ties in unspecified order.
	local zeros_begin = nil
	local zeros_cnt = 0
	local run_begin = nil
	local run_cnt = 0
	for i = 1, #pcs do
		if pcs[i] == 0 then
			if run_cnt == 0 then run_begin = i end
			run_cnt = run_cnt + 1
			if run_cnt > zeros_cnt then
				zeros_begin = run_begin
				zeros_cnt = run_cnt
			end
		else
			run_cnt = 0
		end
	end

	-- format ipv6 address
	local out = {}
	local i = 1
	while i <= 8 do
		if i == zeros_begin then
			-- after a piece the separating ':' is already there
			if i > 1 then
				out[#out+1] = ':'
			else
				out[#out+1] = '::'
			end
			i = i + zeros_cnt
		else
			out[#out+1] = string.format('%x', pcs[i])
			if i ~= 8 then out[#out+1] = ':' end
			i = i + 1
		end
	end
	local str = table.concat(out)

	local mask = self.mask
	if (mask == nil or mask == 128 or withmask == false) and withmask ~= true then
		return str
	else
		return string.format('%s/%d', str, mask)
	end
end

function inet6:__tostring()
	return tostr6(self)
end

--- Address without the mask suffix.
function inet6:ipstring()
	return tostr6(self, false)
end

--- Address with the mask suffix.
function inet6:cidrstring()
	return tostr6(self, true)
end

--- Independent copy of the object (the piece table is copied).
function inet6:clone()
	return inet6.new(self.pcs, self.mask)
end

--- Equal when the mask and all eight pieces match.
function inet6:__eq(other)
	if self.mask ~= other.mask then
		return false
	end
	local spcs = self.pcs
	local opcs = other.pcs
	for i = 1, 8 do
		if spcs[i] ~= opcs[i] then
			return false
		end
	end
	return true
end

--- addr / n: same address with the mask replaced by n.
function inet6:__div(n)
	return inet6.new(self.pcs, n)
end

--- addr ^ n: same address with the mask increased by n.
function inet6:__pow(n)
	return inet6.new(self.pcs, self.mask + n)
end

--- addr + n: add n to the last piece and rebalance.
function inet6:__add(n)
	local new = self:clone()
	local pcs = new.pcs
	pcs[8] = pcs[8] + n
	new:balance(true)
	return new
end

--- addr - n: implemented as addr + (-n).
function inet6:__sub(n)
	return self + (n*-1)
end

--- The address with all host bits cleared, same mask.
function inet6:network()
	local netbits = self.mask
	local pcs = self.pcs
	local newpcs = { 0, 0, 0, 0, 0, 0, 0, 0 }
	for i = 1, 8 do
		if netbits >= i*16 then
			-- piece lies completely inside the network part
			newpcs[i] = pcs[i]
		elseif netbits <= (i-1)*16 then
			break -- the rest is already zero
		else
			-- piece straddles the boundary: clear its host bits
			local netbitsleft = 16-(netbits-((i-1)*16))
			newpcs[i] = lshift(rshift(pcs[i], netbitsleft), netbitsleft)
		end
	end
	return inet6.new(newpcs, netbits)
end

--- Find the twin network by flipping the last network bit.
-- @return the twin, or nil for a /0 (which has no network bit to flip)
function inet6:flip()
	-- find twin by flipping the last network bit
	local mask = self.mask
	if mask == 0 then return nil end
	local block = rshift(mask, 4)+1
	local maskbits = band(mask, 0xf)
	local bitno = 16 - maskbits
	if bitno == 16 then
		-- mask ends on a piece boundary: the bit is the lowest one of the
		-- preceding piece
		block = block - 1
		bitno = 0
	end
	local flipbit = lshift(1, bitno)
	local r = self:clone()
	local val = r.pcs[block]
	r.pcs[block] = bxor(val, flipbit)
	--print(mask, block, maskbits, bitno, flipbit, self, r:balance())
	return r
end

--- net * n: advance by n networks of this mask's size (n may be negative).
function inet6:__mul(n)
	local new = self:clone()
	local mask = new.mask
	local pcs = new.pcs
	local netbitoverflow = mask % 16
	local netbitremainder = (128-mask) % 16
	-- p is the index of the piece holding the last network bit
	local p = (mask - netbitoverflow) / 16
	if netbitremainder ~= 0 then
		p = p + 1
	end
	local was_negative = false
	if n < 0 then
		n = n * -1
		was_negative = true
	end
	-- align n with the last network bit, then split it over two pieces
	local shiftet = lshift(n, netbitremainder)
	local high_shift = rshift(shiftet, 16)
	local low_shift = band(shiftet, 0xffff)
	--print(p, netbitoverflow, hex(shiftet), hex(high_shift), hex(low_shift))
	if was_negative then
		high_shift = -high_shift
		low_shift = -low_shift
	end
	-- The high half belongs to the preceding piece whenever one exists;
	-- guarding with 'p > 2' dropped it for networks ending in piece 2.
	if p > 1 then
		pcs[p-1] = pcs[p-1] + high_shift
	end
	pcs[p] = pcs[p] + low_shift
	new:balance()
	return new
end

local M = {}
local mt = {}

--- Calling the module creates an address object, see inet.new.
function mt.__call(_, ...)
	return inet.new(...)
end

M.is4 = is_inet4
M.is6 = is_inet6
M.is = is_inet

return setmetatable(M, mt)