diff --git a/earwigbot/wiki/copyvios/markov.py b/earwigbot/wiki/copyvios/markov.py index cf26317..9a4717d 100644 --- a/earwigbot/wiki/copyvios/markov.py +++ b/earwigbot/wiki/copyvios/markov.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from collections import defaultdict from re import sub, UNICODE __all__ = ["EMPTY", "EMPTY_INTERSECTION", "MarkovChain", @@ -34,23 +33,27 @@ class MarkovChain(object): def __init__(self, text): self.text = text - self.chain = defaultdict(lambda: defaultdict(lambda: 0)) - words = sub(r"[^\w\s-]", "", text.lower(), flags=UNICODE).split() + self.chain = self._build() + self.size = self._get_size() + def _build(self): + """Build and return the Markov chain from the input text.""" padding = self.degree - 1 + words = sub(r"[^\w\s-]", "", self.text.lower(), flags=UNICODE).split() words = ([self.START] * padding) + words + ([self.END] * padding) - for i in range(len(words) - self.degree + 1): - last = i + self.degree - 1 - self.chain[tuple(words[i:last])][words[last]] += 1 - self.size = self._get_size() + chain = {} + + for i in xrange(len(words) - self.degree + 1): + phrase = tuple(words[i:i+self.degree]) + if phrase in chain: + chain[phrase] += 1 + else: + chain[phrase] = 1 + return chain def _get_size(self): """Return the size of the Markov chain: the total number of nodes.""" - size = 0 - for node in self.chain.itervalues(): - for hits in node.itervalues(): - size += hits - return size + return sum(self.chain.itervalues()) def __repr__(self): """Return the canonical string representation of the MarkovChain.""" @@ -65,20 +68,21 @@ class MarkovChainIntersection(MarkovChain): """Implements the intersection of two chains (i.e., their shared nodes).""" def __init__(self, mc1, mc2): - self.chain = defaultdict(lambda: defaultdict(lambda: 0)) self.mc1, self.mc2 = mc1, mc2 - c1 = mc1.chain - c2 = mc2.chain - - for word, nodes1 in c1.iteritems(): - if word in c2: - nodes2 = c2[word] - for node, count1 in nodes1.iteritems(): - if node in nodes2: - count2 = nodes2[node] - self.chain[word][node] = min(count1, count2) + self.chain = self._build() self.size = self._get_size() + def _build(self): + """Build and return the Markov chain from the input chains.""" + c1 = self.mc1.chain + c2 = self.mc2.chain + chain = {} + + for phrase in c1: + if phrase in c2: + chain[phrase] = min(c1[phrase], c2[phrase]) + return chain + def __repr__(self): """Return the canonical string representation of the intersection.""" res = "MarkovChainIntersection(mc1={0!r}, mc2={1!r})"