require('Module:Lua class')
local libraryUtil = require('libraryUtil')
local TableTools = require('Module:TableTools')
-- Note: hash collisions are not handled yet
--- frozenset: an immutable, hashable set modelled on Python's frozenset.
-- Internal storage (self._elements):
--   * an element that exposes `hash` (e.g. a nested frozenset) is stored as
--     [elem.hash()] = elem, so distinct objects with equal content collapse;
--   * any other element is stored as [elem] = true.
-- __pairs hides this encoding and yields the elements themselves.
-- Methods are called with '.' throughout (e.g. `a.len()`), i.e. `self` is
-- presumably bound by Module:Lua class -- NOTE(review): confirm there.
-- Inside this constructor table the names `frozenset` and `set` resolve to
-- globals (a `local` is not in scope within its own initializer), so the class
-- library is assumed to publish classes globally -- TODO confirm.
local frozenset = class('frozenset', {
	--- Constructor. `args` is the table passed as `frozenset{...}`:
	--   {}              -> empty set
	--   {'text'}        -> the set of the string's characters (via mw.ustring)
	--   {a_set}         -> a copy of a frozenset/set's elements
	--   {tbl}           -> tbl's values if array-like, otherwise its keys
	--   {a, b, ...}     -> the listed elements
	-- Raises "TypeError" (error level 3) when an element is an unhashable
	-- collection.
	__init = function (self, args)
		local elements = {}
		if #args == 0 then -- for performance
			self._elements = elements
			return
		elseif #args == 1 then
			local arg = args[1]
			if type(arg) == 'string' then
				for i = 1, mw.ustring.len(arg) do
					elements[mw.ustring.sub(arg, i, i)] = true
				end
				self._elements = elements
				return
			elseif isinstance(arg, 'frozenset') or isinstance(arg, 'set') then
				-- Copy through __pairs, which yields the elements. The generic path
				-- below would ask TableTools.isArrayLike, which reports an EMPTY set
				-- as array-like, and then call ipairs, whose __ipairs raises
				-- IterationError. Elements of an existing set are already valid.
				for elem in pairs(arg) do
					self._set(elements, elem)
				end
				self._elements = elements
				return
			elseif pcall(pairs, arg) or pcall(ipairs, arg) then
				args = arg
			end
		end
		if TableTools.isArrayLike(args) then
			for i, v in ipairs(args) do
				if (pcall(pairs, v) or pcall(ipairs, v)) and not (isinstance(v) and v.hash) then
					error(("TypeError: invalid element #%d type (got %s, a mutable collection)"):format(i, type(v)), 3)
				end
				self._set(elements, v)
			end
		else
			for k in pairs(args) do
				if (pcall(pairs, k) or pcall(ipairs, k)) and not (isinstance(k) and k.hash) then
					error(("TypeError: invalid element type (got %s, a mutable collection)"):format(type(k)), 3)
				end
				self._set(elements, k)
			end
		end
		self._elements = elements
	end,
	--- Iterate the elements (not the internal keys) of the set.
	-- The iterator keeps its position in the upvalue `k` and ignores the control
	-- value passed back by the generic for; each pairs() call makes a new closure.
	__pairs = function (self)
		local k, v
		local function iterator(elements)
			k, v = next(elements, k)
			if v == true then
				return k -- plain element, stored as [elem] = true
			else
				return v -- hashed element, stored as [hash] = elem; nil at the end
			end
		end
		return iterator, self._elements
	end,
	--- Sets are unordered, so index-based iteration is refused.
	__ipairs = function (self)
		error("IterationError: a set is unordered, use 'pairs' instead", 2)
	end,
	--- Static: look up `elem` in an element table. Returns the stored value
	-- (`true` or the stored element) or nil when absent.
	_get = function (elements, elem)
		if isinstance(elem) and elem.hash then
			return elements[elem.hash()]
		else
			return elements[elem]
		end
	end,
	--- Static: insert `elem` into an element table.
	_set = function (elements, elem)
		if isinstance(elem) and elem.hash then
			elements[elem.hash()] = elem -- otherwise different objects with the same content would duplicate
		else
			elements[elem] = true
		end
	end,
	--- Content hash that does not depend on insertion order.
	-- The keys are sorted with TableTools.keysToList, serialized into a '{...}'
	-- string and hashed with FNV-1a (32-bit) via mw.hash. The result is cached
	-- in self.__hash.
	_hash = function (self)
		if self.__hash then
			return self.__hash
		end
		-- equal frozensets may store their keys in a different order, so sort
		-- them before serializing
		local ordered_keys = TableTools.keysToList(self._elements)
		-- convert keys to strings for table.concat; note that information will be
		-- lost for functions
		for i, key in ipairs(ordered_keys) do
			if type(key) == 'string' then
				-- %q escapes quotes and backslashes, so a single string such as
				-- "a','b" no longer serializes exactly like the two strings 'a','b'
				ordered_keys[i] = ('%q'):format(key)
			else
				ordered_keys[i] = tostring(key)
			end
		end
		local str = '{' .. table.concat(ordered_keys, ',') .. '}' -- wrap in {} to differentiate from tuple
		self.__hash = tonumber('0x' .. mw.hash.hashValue('fnv1a32', str))
		return self.__hash
	end,
	--- Python-like display: {'a', 1, ...} (element order is unspecified).
	__tostring = function (self)
		local string_elems = {}
		for elem in pairs(self) do
			if type(elem) == 'string' then
				string_elems[#string_elems+1] = "'" .. elem .. "'"
			else
				string_elems[#string_elems+1] = tostring(elem)
			end
		end
		local str = '{' .. table.concat(string_elems, ', ') .. '}'
		return str
	end,
	--- Number of elements.
	len = function (self)
		return TableTools.size(self._elements)
	end,
	--- Membership test. A `set` argument is looked up as the equal frozenset;
	-- any other unhashable collection raises "TypeError".
	has = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif (pcall(pairs, elem) or pcall(ipairs, elem)) and not (isinstance(elem) and elem.hash) then
			error(("TypeError: invalid element type (got %s, a mutable collection)"):format(type(elem)), 2)
		end
		return self._get(self._elements, elem) and true or false
	end,
	--- True when self and `other` share no element. Like issubset/issuperset,
	-- `other` may be anything the frozenset constructor accepts.
	isdisjoint = function (self, other)
		other = frozenset{other}
		for elem in pairs(other) do
			if self._get(self._elements, elem) then
				return false
			end
		end
		return true
	end,
	--- self <= other, with `other` converted via the frozenset constructor.
	issubset = function (self, other)
		return self <= frozenset{other}
	end,
	--- Subset test: every element of `a` is in `b`.
	__le = function (a, b)
		for elem in pairs(a) do
			if not b._get(b._elements, elem) then
				return false
			end
		end
		return true
	end,
	--- Proper subset test.
	__lt = function (a, b)
		return a <= b and a.len() < b.len() -- is calculating a's length during its traversal in __le faster?
	end,
	--- self >= other, with `other` converted via the frozenset constructor.
	issuperset = function (self, other)
		return self >= frozenset{other}
	end,
	--- Union with any number of iterables; always returns a (mutable) set.
	union = function (self, ...)
		local sum = set{self}
		sum.update(...)
		return sum
	end,
	--- a | b, as an instance of a's class.
	-- The result is built from a plain list of elements: handing the internal
	-- element table to the constructor would let TableTools.isArrayLike mistake
	-- {[1]=true, [2]=true} for the array {true, true}, and hashed elements would
	-- be replaced by their numeric hashes.
	__add = function (a, b)
		local list = {}
		for elem in pairs(a) do
			list[#list + 1] = elem
		end
		for elem in pairs(b) do
			list[#list + 1] = elem -- duplicates are collapsed by the constructor
		end
		return a.__class{list}
	end,
	--- Intersection with any number of iterables; always returns a set.
	intersection = function (self, ...)
		local product = set{self}
		product.intersection_update(...)
		return product
	end,
	--- a & b, as an instance of a's class (built from a plain list, see __add).
	__mul = function (a, b)
		local list = {}
		for elem in pairs(a) do
			if b._get(b._elements, elem) then
				list[#list + 1] = elem
			end
		end
		return a.__class{list}
	end,
	--- Difference with any number of iterables; always returns a set.
	difference = function (self, ...)
		local difference = set{self}
		difference.difference_update(...)
		return difference
	end,
	--- a - b, as an instance of a's class (built from a plain list, see __add).
	__sub = function (a, b)
		local list = {}
		for elem in pairs(a) do
			if not b._get(b._elements, elem) then
				list[#list + 1] = elem
			end
		end
		return a.__class{list}
	end,
	--- self ^ other, with `other` converted via the frozenset constructor.
	symmetric_difference = function (self, other)
		return self ^ frozenset{other}
	end,
	--- a ^ b (elements in exactly one operand), as an instance of a's class
	-- (built from a plain list, see __add).
	__pow = function (a, b)
		local list = {}
		for elem in pairs(a) do
			if not b._get(b._elements, elem) then
				list[#list + 1] = elem
			end
		end
		for elem in pairs(b) do
			if not a._get(a._elements, elem) then
				list[#list + 1] = elem
			end
		end
		return a.__class{list}
	end,
	--- Shallow copy of the same class.
	copy = function (self)
		return self.__class{self}
	end,
	--- Equality: mutual subset.
	__eq = function (a, b)
		return a <= b and a >= b
	end,
	__staticmethods = {'_get', '_set'},
	__protected = {'_get', '_set'}
})
--- set: the mutable variant of frozenset. It inherits every query method and
-- operator from frozenset and adds in-place mutators. Arguments of the
-- *_update methods may be anything the frozenset constructor accepts.
local set = class('set', frozenset, {
	--- Static: remove `elem` from an element table (inverse of _set).
	-- Removing an absent element is a no-op.
	_del = function (elements, elem)
		if isinstance(elem) and elem.hash then
			elements[elem.hash()] = nil
		else
			elements[elem] = nil
		end
	end,
	--- In-place union with every argument.
	update = function (self, ...)
		local others, other = {...}, nil
		for i = 1, select('#', ...) do
			other = frozenset{others[i]}
			for elem in pairs(other) do
				self._set(self._elements, elem)
			end
		end
	end,
	--- Keep only the elements present in every argument.
	-- Deleting the current key while pairs() traverses self is allowed in Lua
	-- (clearing existing fields during next()).
	intersection_update = function (self, ...)
		local others, other = {...}, nil
		for i = 1, select('#', ...) do
			other = frozenset{others[i]}
			for elem in pairs(self) do
				if not other.has(elem) then -- probably faster than iterating through (likely longer) other to access self._get
					self._del(self._elements, elem)
				end
			end
		end
	end,
	--- Remove every element found in any argument.
	difference_update = function (self, ...)
		local others, other = {...}, nil
		for i = 1, select('#', ...) do
			other = frozenset{others[i]}
			for elem in pairs(self) do
				if other.has(elem) then
					self._del(self._elements, elem)
				end
			end
		end
	end,
	--- Keep only the elements found in exactly one of self and `other`.
	-- Each element of `other` toggles membership, decided by self's state before
	-- that element is touched. Deleting the common elements first and then adding
	-- everything missing would re-add exactly the elements just deleted.
	symmetric_difference_update = function (self, other)
		other = frozenset{other} -- always a fresh copy, so self may be mutated freely
		for elem in pairs(other) do
			if self._get(self._elements, elem) then
				self._del(self._elements, elem)
			else
				self._set(self._elements, elem)
			end
		end
	end,
	--- Insert one element. An unhashable collection raises "TypeError".
	add = function (self, elem)
		if (pcall(pairs, elem) or pcall(ipairs, elem)) and not (isinstance(elem) and elem.hash) then
			error(("TypeError: invalid element type (got %s, a mutable collection)"):format(type(elem)), 2)
		end
		self._set(self._elements, elem)
	end,
	--- Remove `elem`; raises "KeyError" when it is absent. A `set` argument is
	-- looked up as the equal frozenset.
	remove = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif (pcall(pairs, elem) or pcall(ipairs, elem)) and not (isinstance(elem) and elem.hash) then
			error(("TypeError: invalid element type (got %s, a mutable collection)"):format(type(elem)), 2)
		end
		if not self._get(self._elements, elem) then -- cannot use self.has since error level would be incorrect
			error(("KeyError: %s"):format(tostring(elem)), 2)
		end
		self._del(self._elements, elem)
	end,
	--- Like remove, but a missing element is silently ignored.
	discard = function (self, elem)
		if isinstance(elem, 'set') then
			elem = frozenset{elem}
		elseif (pcall(pairs, elem) or pcall(ipairs, elem)) and not (isinstance(elem) and elem.hash) then
			error(("TypeError: invalid element type (got %s, a mutable collection)"):format(type(elem)), 2)
		end
		self._del(self._elements, elem)
	end,
	--- Remove and return an arbitrary element; raises "KeyError" when empty.
	pop = function (self)
		local k, v = next(self._elements)
		if k == nil then
			error("KeyError: pop from an empty set", 2)
		end
		if v == true then
			self._del(self._elements, k) -- plain element stored as [elem] = true
			return k
		else
			self._del(self._elements, v) -- hashed element stored as [hash] = elem
			return v
		end
	end,
	--- Remove all elements.
	clear = function (self)
		self._elements = {}
	end,
	__staticmethods = {'_del'},
	__protected = {'_del'}
})
-- Module exports, by position: [1] is the frozenset class, [2] the set class.
return { [1] = frozenset, [2] = set }