#!/usr/bin/env python

# Typogenetics simulator by David Fifield <david@bamsoftware.com>
# http://www.bamsoftware.com/hacks/geb/index.html#typo
# This program is in the public domain.

import getopt
import sys
import time

try:
    import psyco
    psyco.full()
except ImportError:
    pass

# Module-level settings; main() overwrites them from the command line.
# Control with the -n option: process() gives up on any generation deeper than
# this (depth 0 is the starting string).
max_depth = 2
# Control with the --test option.
one_shot = False
# Control with the --verbose option.
verbose = False
# Set from the optional positional argument: the string to test, or the string
# at which the search starts.
start_string = ""

def is_purine(base):
    """Return True if base is one of the purines, A or G."""
    return base in ("A", "G")

def is_pyrimidine(base):
    """Return True if base is one of the pyrimidines, C or T."""
    return base in ("C", "T")

# Base pairing used for the complementary strand: A pairs with T, C with G.
COMPLEMENT_MAP = dict(A="T", T="A", C="G", G="C")

def complement(base):
    """Return the base that pairs with base.

    Raises KeyError for anything other than A, C, G or T (including None).
    """
    partner = COMPLEMENT_MAP[base]
    return partner

class Unit(object):
    """One position in a strand: a base plus links to its neighbours.

    left/right point at the adjacent units in the same strand, and other at
    the paired unit on the complementary strand. Any of them, as well as
    base, may be None.
    """
    def __init__(self, base = None):
        self.base = base
        self.left = self.right = None
        self.other = None

    def head(self):
        """Return the leftmost unit reachable by following left links."""
        node = self
        while True:
            prev = node.left
            if prev is None:
                return node
            node = prev

    def find_right(self, base):
        """Return the first unit whose base equals base, scanning rightward
        from (and including) this unit, or None if there is none."""
        node = self
        while node is not None and node.base != base:
            node = node.right
        return node

    def __iter__(self):
        """Yield this unit and then every unit to its right, in order."""
        node = self
        while node is not None:
            yield node
            node = node.right

class Pair(Unit):
    """A unit created together with its partner on the complementary strand.

    If comp is true the partner carries the complement of base; otherwise the
    partner's base is None (an empty slot opposite this unit).
    """
    def __init__(self, base = None, comp = False):
        Unit.__init__(self, base)
        # Explicit branch instead of the fragile `comp and x or None` idiom.
        if comp:
            other_base = complement(base)
        else:
            other_base = None
        other = Unit(other_base)
        self.other = other
        other.other = self

    def insert_right(self, unit):
        """Splice this pair into unit's strand immediately to its right.

        self is linked between unit and unit.right. The complementary strand
        is linked in the opposite direction, so self.other is linked between
        unit.other.left and unit.other. If unit is None nothing is linked.
        Returns self in both cases.
        """
        if unit is None:
            return self
        self.left = unit
        self.right = unit.right
        # Assuming the two strands mirror each other (which this method keeps
        # true), unit.other.left is the partner of unit.right, or None.
        self.other.left = unit.other.left
        self.other.right = unit.other
        if unit.right is not None:
            unit.right.left = self
            unit.right.other.right = self.other
        unit.right = self
        unit.other.left = self.other
        return self

def string_to_strand(string):
    """Build a strand from string and return its leftmost unit.

    Each base becomes a Pair whose partner slot on the complementary strand
    is empty. An empty string yields None (the same "no strand" value that
    strand_to_string accepts) instead of crashing on None.head().
    """
    pair = None
    for base in string:
        pair = Pair(base).insert_right(pair)
    if pair is None:
        return None
    return pair.head()

def strand_to_string(strand):
    """Render one half of a strand as text, reading right from its head.

    Units whose base is missing are rendered as spaces; a None strand
    renders as the empty string.
    """
    if strand is None:
        return ""
    return "".join([unit.base or " " for unit in strand.head()])

def strand_split(strand):
    """Return the whitespace-separated base strings found on this half of the
    strand, followed by those on its complementary half."""
    pieces = []
    for half in (strand, strand.other):
        pieces.extend(strand_to_string(half).split())
    return pieces

# For the execute method of each op, self is an Enzyme object.

# Turn contributed by each op to its enzyme's folding. Enzyme.append adds these
# values to binding_index (mod 4), which indexes Enzyme.BINDING_MAP
# ["A", "C", "T", "G"], so LEFT steps A -> C -> T -> G and RIGHT steps the
# other way.
STRAIGHT = 0
LEFT = +1
RIGHT = -1

class op_cut(object):
    """cut: sever both halves of the strand just right of the bound unit."""
    turn = STRAIGHT
    acid = "cut"
    @staticmethod
    def execute(self):
        here = self.unit
        severed = here.right
        # Unlink on this half and, mirrored, on the complementary half.
        here.right = None
        here.other.left = None
        # The enzyme stays on the left piece, which becomes its strand.
        self.strand = here
        if severed is None:
            return
        severed.left = None
        severed.other.right = None
        self.cuts.append(severed)

class op_del(object):
    """del: blank out the bound unit's base, then step one unit right."""
    turn = STRAIGHT
    acid = "del"
    @staticmethod
    def execute(self):
        self.unit.base = None
        nxt = self.unit.right
        self.unit = nxt
        # In copy mode the newly reached base is complemented opposite itself.
        if self.copy_mode and nxt is not None and nxt.base is not None:
            nxt.other.base = complement(nxt.base)

class op_swi(object):
    """swi: jump to the paired unit on the complementary strand."""
    turn = RIGHT
    acid = "swi"
    @staticmethod
    def execute(self):
        partner = self.unit.other
        self.unit = partner

class op_mvr(object):
    """mvr: step one unit to the right."""
    turn = STRAIGHT
    acid = "mvr"
    @staticmethod
    def execute(self):
        target = self.unit.right
        self.unit = target
        # Only copy when in copy mode and a base was actually reached.
        if not self.copy_mode or target is None or target.base is None:
            return
        target.other.base = complement(target.base)

class op_mvl(object):
    """mvl: step one unit to the left."""
    turn = STRAIGHT
    acid = "mvl"
    @staticmethod
    def execute(self):
        target = self.unit.left
        self.unit = target
        # Only copy when in copy mode and a base was actually reached.
        if not self.copy_mode or target is None or target.base is None:
            return
        target.other.base = complement(target.base)

class op_cop(object):
    """cop: enter copy mode, complementing the bound unit immediately."""
    turn = RIGHT
    acid = "cop"
    @staticmethod
    def execute(self):
        self.copy_mode = True
        base = self.unit.base
        if base is None:
            return
        self.unit.other.base = complement(base)

class op_off(object):
    """off: leave copy mode."""
    acid = "off"
    turn = LEFT

    @staticmethod
    def execute(self):
        # Later moves stop writing complements onto the other strand.
        self.copy_mode = False

def _insert_base(enzyme, base):
    """Insert a new pair holding base just right of the enzyme's unit and move
    the enzyme onto it. In copy mode the new partner gets the complementary
    base; otherwise the partner slot is empty."""
    fresh = Pair(base, enzyme.copy_mode)
    enzyme.unit = fresh.insert_right(enzyme.unit)

class op_ina(object):
    """ina: insert A to the right of the bound unit."""
    turn = STRAIGHT
    acid = "ina"
    @staticmethod
    def execute(self):
        _insert_base(self, "A")

class op_inc(object):
    """inc: insert C to the right of the bound unit."""
    turn = RIGHT
    acid = "inc"
    @staticmethod
    def execute(self):
        _insert_base(self, "C")

class op_ing(object):
    """ing: insert G to the right of the bound unit."""
    turn = RIGHT
    acid = "ing"
    @staticmethod
    def execute(self):
        _insert_base(self, "G")

class op_int(object):
    """int: insert T to the right of the bound unit."""
    turn = LEFT
    acid = "int"
    @staticmethod
    def execute(self):
        _insert_base(self, "T")

def _scan(enzyme, direction, wanted):
    """Step the enzyme along direction ("left" or "right") until it rests on a
    base for which wanted(base) is true, or it runs off the strand or onto an
    empty slot. In copy mode every base landed on, including the final one,
    is complemented onto the other strand."""
    enzyme.unit = getattr(enzyme.unit, direction)
    while enzyme.unit is not None and enzyme.unit.base is not None:
        here = enzyme.unit
        if enzyme.copy_mode:
            here.other.base = complement(here.base)
        if wanted(here.base):
            break
        enzyme.unit = getattr(here, direction)

class op_rpy(object):
    """rpy: search rightward for a pyrimidine."""
    turn = RIGHT
    acid = "rpy"
    @staticmethod
    def execute(self):
        _scan(self, "right", is_pyrimidine)

class op_rpu(object):
    """rpu: search rightward for a purine."""
    turn = LEFT
    acid = "rpu"
    @staticmethod
    def execute(self):
        _scan(self, "right", is_purine)

class op_lpy(object):
    """lpy: search leftward for a pyrimidine."""
    turn = RIGHT
    acid = "lpy"
    @staticmethod
    def execute(self):
        _scan(self, "left", is_pyrimidine)

class op_lpu(object):
    """lpu: search leftward for a purine."""
    turn = LEFT
    acid = "lpu"
    @staticmethod
    def execute(self):
        _scan(self, "left", is_purine)

class Enzyme(object):
    """A chain of ops (amino acids) plus the working state used to translate.

    binding_index is folded from the ops' turns as they are appended and
    selects, through BINDING_MAP, the base the enzyme binds to
    (binding_pref). copy_mode, cuts, strand and unit are reset and updated
    by translate().
    """
    BINDING_MAP = [ "A", "C", "T", "G" ]

    def __init__(self):
        self.ops = []
        self.binding_index = 0
        self.copy_mode = False
        self.cuts = []
        self.strand = None
        self.unit = None

    def append(self, op):
        """Add op to the end of the chain and update the binding index.

        Each append folds in the turn of the op that was last, but only once
        two ops are already present. So only ops strictly between the first
        and the last one contribute their turn.
        """
        if len(self.ops) >= 2:
            folded = self.binding_index + self.ops[-1].turn
            self.binding_index = folded % len(Enzyme.BINDING_MAP)
        self.ops.append(op)

    def translate(self, strand, start):
        """Run the ops with the enzyme bound at unit start.

        The strand argument is not used; work begins from start. Execution
        stops early once the enzyme is off the strand or on an empty slot.
        Returns the base strings on both halves of every cut-off piece, then
        those of the piece containing self.strand.
        """
        self.strand = self.unit = start
        self.cuts = []
        self.copy_mode = False
        for op in self.ops:
            here = self.unit
            if here is None or here.base is None:
                break
            op.execute(self)

        pieces = []
        for piece in self.cuts + [self.strand]:
            pieces.extend(strand_split(piece))
        return pieces

    def __cmp__(self, other):
        # Python 2 ordering: compare by op sequence.
        return cmp(self.ops, other.ops)

    def __len__(self):
        return len(self.ops)

    def __str__(self):
        names = "-".join([op.acid for op in self.ops])
        return "%s(%s)" % (names, self.binding_pref)

    def __repr__(self):
        return self.__str__()

    @property
    def binding_pref(self):
        """The base this enzyme binds to."""
        return Enzyme.BINDING_MAP[self.binding_index]

# Translation table from base duplets to ops (amino acids). "AA" maps to None,
# which ribosome() treats as punctuation that ends the current enzyme. All 16
# duplets over A, C, G, T have an entry.
RIBOSOME_MAP = {
    "AA": None,
    "AC": op_cut,
    "AG": op_del,
    "AT": op_swi,
    "CA": op_mvr,
    "CC": op_mvl,
    "CG": op_cop,
    "CT": op_off,
    "GA": op_ina,
    "GC": op_inc,
    "GG": op_ing,
    "GT": op_int,
    "TA": op_rpy,
    "TC": op_rpu,
    "TG": op_lpy,
    "TT": op_lpu,
}

def ribosome(string):
    """Translate string into the list of enzymes it codes for.

    The string is read in duplets (an odd trailing base is ignored) and each
    duplet is looked up in RIBOSOME_MAP; a KeyError is raised for any other
    character. Duplets that map to None end the current enzyme. Empty
    enzymes are dropped.
    """
    enzymes = []
    current = Enzyme()
    for start in range(0, len(string) - 1, 2):
        op = RIBOSOME_MAP[string[start:start + 2]]
        if op is not None:
            current.append(op)
            continue
        # Punctuation: close off the enzyme built so far, if any.
        if len(current) > 0:
            enzymes.append(current)
        current = Enzyme()
    if len(current) > 0:
        enzymes.append(current)
    return enzymes

def strings_of_size(n):
    """Yield every string of exactly n bases over A, C, G, T in lexicographic
    order, starting with "A" * n."""
    if n != 0:
        for head in "ACGT":
            for tail in strings_of_size(n - 1):
                yield head + tail
    else:
        yield ""

def strings_gen(start = ""):
    """Yield start and then every following string, forever.

    Strings are treated as base-4 numerals with digits A < C < G < T, most
    significant first. Incrementing past all T's grows the string, so after
    "" comes "A" ... "T", "AA", ..., and after "TT" comes "AAA".
    """
    # Successor of each non-carrying digit; "T" wraps to "A" and carries.
    successor = {"A": "C", "C": "G", "G": "T"}
    yield start
    # Least significant digit first, so carries move towards higher indices.
    digits = list(start)
    digits.reverse()
    while True:
        for i, digit in enumerate(digits):
            if digit in successor:
                digits[i] = successor[digit]
                break
            if digit == "T":
                digits[i] = "A"
        else:
            # Every digit carried: the string grows by one position.
            digits.append("A")
        yield "".join(reversed(digits))

def binding_site(enzyme, strand, n):
    """Return the unit holding the n-th (0-based) occurrence of the enzyme's
    binding preference, scanning rightward from strand.

    Returns None when there are not that many occurrences, including when
    the last occurrence is the final unit of the strand (previously that
    case raised AttributeError on None.find_right).
    """
    u = strand.find_right(enzyme.binding_pref)
    while u is not None and n > 0:
        if u.right is None:
            # Nothing to the right of the last match: no further occurrence.
            return None
        u = u.right.find_right(enzyme.binding_pref)
        n -= 1
    return u

def enzyme_string_pairs(enzymes, strings, partial = None):
    """Yield every way of assigning enzymes to binding sites on strings.

    Taking the enzymes in order, each one either binds to one occurrence of
    its binding preference in one of the remaining strings (a string takes
    at most one enzyme) or is left unbound. Yields (pairs, leftover) tuples:
    pairs is a list of (enzyme, string, n) triples, where n selects the n-th
    occurrence of the preference. leftover is the list of strings no enzyme
    bound to and may be the very list passed in as strings. Equal adjacent
    strings are tried only once, so strings should be sorted.

    partial is the list of pairs chosen so far (used by the recursion).
    """
    if partial is None:
        # Fresh list per call: a shared mutable default would be handed to
        # callers and could be corrupted across calls.
        partial = []
    if not enzymes or not strings:
        yield partial, strings
        return
    prev = None
    for i in range(len(strings)):
        if strings[i] == prev:
            continue
        prev = strings[i]
        for j in range(strings[i].count(enzymes[0].binding_pref)):
            pair = (enzymes[0], strings[i], j)
            for p, s in enzyme_string_pairs(enzymes[1:], strings[:i] + strings[i + 1:], partial + [pair]):
                yield p, s
    # Also try leaving this enzyme unbound.
    for p, s in enzyme_string_pairs(enzymes[1:], strings, partial):
        yield p, s

cache = {}

def process(base_string, strings, depth = 0):
    global cache

    if depth > max_depth:
        return False

    if strings.count(base_string) > 1:
        return True

    if len(cache) > 100000:
        cache = {}

    strings.sort()

    k = (depth, "*".join(strings))
    if k in cache:
        return False
    cache[k] = True

    if verbose:
        print depth, strings

    enzymes = []
    for s in strings:
        enzymes.extend(ribosome(s))
    for pairs, new_strings in enzyme_string_pairs(enzymes, strings):
        if verbose:
            print depth, "->", pairs, new_strings
        for enzyme, string, n in pairs:
            strand = string_to_strand(string)
            unit = binding_site(enzyme, strand, n)
            assert unit is not None
            new_strings.extend(enzyme.translate(strand, unit))
        if process(base_string, new_strings, depth + 1):
            return True

def is_replicator(string):
    """Return a true value if string, starting alone, produces a pool holding
    two copies of itself within max_depth generations."""
    initial_pool = [string]
    return process(string, initial_pool)

def find_replicator(start = ""):
    start_time = time.time()
    for base_string in strings_gen(start):
        print "%06.f %s" % (time.time() - start_time, base_string)
        sys.stdout.flush()
        if is_replicator(base_string):
            print "*** %s is a self-rep" % base_string
            break

def usage():
    print """\
Usage: %(prog)s [option] [start]
Search for typogenetic self-reps or test if a string is a self-rep.
start is the string at which to start searching, or the string to test with
--single. It defaults to the empty string. When searching, after AAAA comes
AAAC, and so on. The program prints every string it tries and quits when it
finds a replicator.

  -h, --help     display this help
  -n, --num      how many generations to simulate (default 2)
  -t, --test     test only the single given string and exit
  -v, --verbose  show verbose details of string evolution and backtracking

Examples:
  %(prog)s AAAAAA
  %(prog)s --test ACTTCG
  %(prog)s --test --verbose CGTTTTTTTG\
""" % {"prog": sys.argv[0]}

def usage_error(s):
    """Report a command-line error and a help hint on stderr, then exit
    with status 1."""
    prog = sys.argv[0]
    sys.stderr.write("%s: %s\n" % (prog, s))
    sys.stderr.write("Try '%s -h' for help. \n" % prog)
    sys.exit(1)

def main():
    global max_depth, one_shot, verbose, start_string

    try:
        opts, args = getopt.gnu_getopt(sys.argv[1:], "hn:tv", ["help", "num=", "test", "verbose"])
    except getopt.GetoptError, e:
        usage_error(e.msg)
    for o, a in opts:
        if o == "-h" or o == "--help":
            usage()
            sys.exit(0)
        elif o == "-n" or o == "--num":
            max_depth = int(a)
        elif o == "--test":
            one_shot = True
        elif o == "-v" or o == "--verbose":
            verbose = True

    if len(args) > 1:
        usage_error("Only one start string is allowed (got %d)" % len(args))
    if len(args) == 1:
        start_string = args[0]

    if one_shot:
        if len(start_string) == 0:
            usage_error("--test requires a string to test.")
        if is_replicator(start_string):
            print "*** %s is a self-rep" % start_string
            sys.exit(0)
        else:
            print "%s is not a self-rep" % start_string
            sys.exit(1)
    else:
        find_replicator(start_string)
        sys.exit(0)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()

# An example of using the functions in this program:
# s = "AATCCG"
# e = ribosome(s)[0]
# strand = string_to_strand(s)
# unit = binding_site(e, strand, 0)
# print e.translate(strand, unit)
