#!/usr/bin/python3

import re
import os
import sys

# Error messages collected while checking one model.  process_one() rebinds
# this to a fresh list at the start of each run and writes the collected
# messages to <model>_sequenceErrors.txt at the end.
error_lines = []


class Part(object):
    """Represents a part in the calculation graph

    Links are recorded on both ends: ``to_nodes`` lists the parts this part
    feeds and ``from_nodes`` lists the parts that feed it.
    """

    def __init__(self, part_id: str):
        """
        :param part_id: A PartID.  Must be a string.
        :raises TypeError: if part_id is not a string.
        """
        if not isinstance(part_id, str):
            # TypeError is still an Exception, so existing handlers keep working.
            raise TypeError("id is not string")
        self.id = part_id
        self.to_nodes = []
        self.from_nodes = []
        self.visited = False
        self.type = None
        # Part number.  It is only assigned when the part is explicitly read,
        # so it must exist (as None) on implicitly created parts as well;
        # otherwise any code that formats ``part.num`` raises AttributeError.
        self.num = None
        # Retained for backward compatibility; not read by the visible code.
        self.none = None

        # this node has been explicitly read.
        # Each node should be read explicitly once and only once.
        self.explicit = False

    def add_link_to(self, to_node):
        """
        Link two Part nodes, so that 'self' is the from node.
        Adding the same link twice has no further effect.

        :param to_node: A Part that is the "to node" of the new link
        :return: None
        """
        if to_node not in self.to_nodes:
            self.to_nodes.append(to_node)
        if self not in to_node.from_nodes:
            to_node.from_nodes.append(self)

    def clear_visit(self):
        """Mark this part as not visited."""
        self.visited = False

    def is_ipcx_receive(self):
        """
        A part counts as an IPCx receiver when its type is "IPCx" and at least
        one of the parts feeding it has type "GROUND".

        :return: return True if this is an IPCx receiver
        """
        if self.type != "IPCx":
            return False
        return any(node.type == "GROUND" for node in self.from_nodes)

    def __repr__(self):
        return f"Part('{self.id}')"


class CalcGraph(object):
    """
    Represents the Directed Acyclic Graph (DAG) of calculations generated by the RCG for a model
    """

    def __init__(self):
        # dictionary of part name (string) : Part
        self.nodes = {}
        # Nodes that are not the target of any link.  add() puts every new
        # node here and add_link() removes a node once something links to it.
        self.root_nodes = set()

    def add(self, part_id):
        """
        Add an as yet unconnected Part node to the graph.
        An existing entry with the same part_id is replaced.

        :param part_id: PartID, The part_id of the part
        :return: the new Part
        """
        node = Part(part_id)
        self.nodes[part_id] = node
        self.root_nodes.add(node)
        return node

    def add_explicit(self, part_id, part_num, part_type):
        """
        Add a Part node, or mark an implicitly created one (added through a
        link) as explicitly read.  If the part was already explicitly added,
        a message is appended to error_lines and the existing node is returned
        unchanged; no exception is raised.

        :param part_id: PartID of new Part
        :param part_type: string.  part_type of new part
        :param part_num: Number of part, stored as given.
        :return: the new (or already existing) Part
        """
        node = self.nodes.get(part_id)
        if node is None:
            node = self.add(part_id)
        elif node.explicit:
            error_lines.append(f"{node} was already explicitly added to the graph.")
            return node
        node.explicit = True
        node.type = part_type
        node.num = part_num
        return node

    def find_or_add(self, part_id):
        """
        Return a node if it already exists, or implicitly create one.

        :param part_id: PartID of the node to return
        :return: the node
        """
        node = self.nodes.get(part_id)
        if node is None:
            node = self.add(part_id)
        return node

    def check_sequence(self, sequence):
        """
        Compare a sequence to the graph to see if it's a valid calculation sequence.
        If the sequence traverses the nodes so that all inputs to a node are visited before
        the node, then the sequence is properly ordered.

        Part ids that are not in the graph are reported in error_lines and
        skipped.  Parts are marked as visited as they are processed; the
        flags are not reset first.

        :param sequence: An iterable sequence of PartIDs that represent a calculation
        sequence that should satisfy the graph
        :return: True, [] if sequence is good. otherwise  False, [(Part1, Part2) ...].
        For each tuple returned on failure, Part1 should have been calculated after Part2, but was not.
        """
        wrong_order = []
        for part_id in sequence:
            part = self.nodes.get(part_id)
            if part is None:
                error_lines.append(f"Could not find part id {part_id} in graph")
                # Nothing to check for an unknown part; indexing self.nodes
                # here would raise KeyError and abort the whole check.
                continue
            # GROUND parts are skipped, incoming IPCs are always sequenced
            # first, and delays are always calculated last, so none of them
            # take part in the ordering check (and none are marked visited).
            if part.type in ("GROUND", "DELAY") or part.is_ipcx_receive():
                continue
            for out in part.to_nodes:
                if out.visited:
                    # Implicitly created nodes may never have been given a
                    # number, so read it defensively.
                    out_num = getattr(out, "num", None)
                    part_num = getattr(part, "num", None)
                    error_lines.append(f"{out.id} ({out_num}) calculated before {part.id} ({part_num})")
                    wrong_order.append((out, part))
            part.visited = True
        return not wrong_order, wrong_order

    def add_link(self, from_node, to_node):
        """
        Create a link between the two nodes.  Update roots if necessary.

        :param from_node: Part at the source end of the link
        :param to_node: Part at the destination end; it stops being a root
        :return: None
        """
        # discard() is a no-op when to_node is not currently a root.
        self.root_nodes.discard(to_node)
        from_node.add_link_to(to_node)

    def get_unvisited_parts(self):
        """
        :return: a list of all nodes that haven't been visited
        """
        return [n for n in self.nodes.values() if not n.visited]


def read_sequence(model):
    """
    Load the calculation order stored in ``<model>_partSequence.txt``.

    Every line is stripped and split on commas.  A line that yields exactly
    two fields contributes its second field to the result; any other line
    is skipped.

    :param model: string, a model name (prefix of the file name)
    :return: A list of strings, the second field of each two-field line, in file order.
    """
    sequence = []
    with open(f"{model}_partSequence.txt", "rt") as seq_file:
        for raw in seq_file:
            fields = raw.strip().split(",")
            if len(fields) != 2:
                continue
            sequence.append(fields[1])
    return sequence


# First line of a part record:
#   "Part <num> <name> is type <type> with <n> inputs and <m> outputs"
# Groups: 1 = part number, 2 = part name, 3 = part type,
#         4 = input count, 5 = output count (either count may be empty, hence \d*).
part_line_re = re.compile(r'Part (\d+) ([^\n]+) is type (\S+) with (\d*) inputs and (\d*) outputs')
# A link line: leading whitespace, one whitespace-free token (optionally
# followed by " Name"), then trailing whitespace.  read_graph() uses group 1
# as the name of the linked part.
link_re = re.compile(r'\s+(\S+(?: Name)?)\s+')
# A line with "Part Name" after leading whitespace; read_graph() skips such
# lines while reading links.
header_re = re.compile(r'\s+Part Name')


def _report_link_count_mismatch(node, num_inputs, inputs_read, num_outputs, outputs_read):
    """Append a message to error_lines for each declared link count of node that differs from what was read."""
    # don't check input counts on DAC since they are hard coded
    # no matter the actual input numbers
    if node.type not in ("Dac", "Dac18", "Dac20", "BUSS"):
        if num_inputs != len(inputs_read):
            from_nodes = "\n".join([str(i) for i in node.from_nodes])
            # NOTE(review): the comparison uses the number of link lines read,
            # while the message reports len(node.from_nodes); the two can
            # differ.  Message text is kept exactly as before.
            error_lines.append(
                f"\n        Wrong number of inputs for {node}"
                f"\n        Expected {num_inputs} but read {len(node.from_nodes)}:"
                f"\n        {from_nodes}"
                "\n" + " " * 12
            )
    if num_outputs != len(outputs_read):
        to_nodes = "\n".join([str(i) for i in node.to_nodes])
        error_lines.append(
            f"\n    Wrong number of outputs for {node}"
            f"\n    Expected {num_outputs} but read {len(node.to_nodes)}:"
            f"\n    {to_nodes}"
            "\n" + " " * 28
        )


def read_graph(model):
    """
    Read in the DAG of the calculation model from ``<model>_partConnectionList.txt``.

    The file is parsed line by line with a small state machine:

    * "look for part"    - wait for a line matching part_line_re.  Parts of
      type INPUT or OUTPUT are ignored; any other part is added explicitly.
    * "look for inputs"  - wait for the line "INS FROM:".
    * "read inputs"      - each line matching link_re adds a link into the
      current part; "OUT TO:" switches to reading outputs; header lines are
      skipped; any other line switches to "look for outputs".
    * "look for outputs" - wait for the line "OUT TO:".
    * "read outputs"     - each line matching link_re adds a link out of the
      current part; header lines are skipped; any other line ends the part.

    When a part ends, the declared input/output counts are compared with the
    number of links read and mismatches are appended to error_lines.  The
    same comparison is made if the file ends while outputs are still being
    read, so the last part in the file is validated too.

    :param model: string, a model name
    :return: the populated CalcGraph
    """
    fname = f"{model}_partConnectionList.txt"
    graph = CalcGraph()
    state = "look for part"
    current_node = None
    num_inputs = 0
    num_outputs = 0
    inputs_read = []
    outputs_read = []
    with open(fname, "rt") as f:
        for line in f:
            if state == "look for part":
                m = part_line_re.match(line)
                if m:
                    new_num, new_name, new_type = m.group(1, 2, 3)
                    # The counts are matched with \d*, so they may be empty.
                    num_inputs = int(m.group(4)) if m.group(4) else 0
                    num_outputs = int(m.group(5)) if m.group(5) else 0
                    if new_type not in ("INPUT", "OUTPUT"):
                        current_node = graph.add_explicit(new_name, new_num, new_type)
                        state = "look for inputs"
                        inputs_read = []
                        outputs_read = []
            elif state == "look for inputs":
                if line.strip() == "INS FROM:":
                    state = "read inputs"
            elif state == "read inputs":
                if line.strip() == "OUT TO:":
                    state = "read outputs"
                elif not header_re.match(line):
                    m = link_re.match(line)
                    if m:
                        from_node = graph.find_or_add(m.group(1))
                        graph.add_link(from_node, current_node)
                        inputs_read.append(from_node.id)
                    else:
                        state = "look for outputs"
            elif state == "look for outputs":
                if line.strip() == "OUT TO:":
                    state = "read outputs"
            elif state == "read outputs":
                if not header_re.match(line):
                    m = link_re.match(line)
                    if m:
                        to_node = graph.find_or_add(m.group(1))
                        graph.add_link(current_node, to_node)
                        outputs_read.append(to_node.id)
                    else:
                        # First non-link line after the outputs: the part is complete.
                        # NOTE(review): this line itself is not re-examined as a
                        # possible "Part ..." line - confirm that parts are always
                        # separated by at least one other line.
                        _report_link_count_mismatch(
                            current_node, num_inputs, inputs_read, num_outputs, outputs_read)
                        state = "look for part"
            else:
                raise Exception(f"unknown state '{state}'")
    # The file ended right after an output link: no terminating line was seen,
    # so the final part has not been validated yet.
    if state == "read outputs":
        _report_link_count_mismatch(
            current_node, num_inputs, inputs_read, num_outputs, outputs_read)
    return graph


def check_unvisited(graph: CalcGraph):
    """
    Record an error for every part that was left unvisited but should not have been.

    An unvisited part is not reported when its type is None, when its type is
    one of the exempt types listed below, or when it is an IPCx receiver.

    :param graph: the CalcGraph to inspect
    :return: None
    """
    exempt_types = (
        "EXC", "StateWord", "ModelRate", "DELAY", "EpicsIn", "TERM",
        "GROUND", "Adc", "BUSS", "BUSC", "CONSTANT", "Gps", "Parameters",
        "GOTO", "FROM", "OUTPUT", "INPUT",
    )
    for part in graph.get_unvisited_parts():
        if part.type is None:
            continue
        if part.type in exempt_types:
            continue
        if part.is_ipcx_receive():
            continue
        error_lines.append(f"{part.id} ({part.type}) was not visited.")


# Matches "<model>_partSequence.txt" file names; group 1 is the model name.
# The dot is escaped and the pattern is anchored at the end so that names such
# as "x_partSequence.txt.bak" or "x_partSequenceXtxt" are not taken for
# sequence files.
seqfile_re = re.compile(r'(\w+)_partSequence\.txt$')


def process_one(model):
    """
    Check one model and write its findings to ``<model>_sequenceErrors.txt``.

    The module-level error_lines list is reset, the sequence and the graph
    are read, the sequence is checked against the graph, unvisited parts are
    reported, and every collected message is written to the output file,
    one per line.

    :param model: string, a model name
    :return: True on success (no error messages were collected)
    """
    global error_lines
    error_lines = []
    print(f"processing {model}")

    sequence = read_sequence(model)
    graph = read_graph(model)

    # The result is not used here; problems are also recorded in error_lines.
    _ordered, _misplaced = graph.check_sequence(sequence)
    check_unvisited(graph)

    with open(f"{model}_sequenceErrors.txt", "wt") as report:
        report.writelines(message + "\n" for message in error_lines)

    return not error_lines


def process_all():
    """
    Run process_one() for every model that has a sequence file in the
    current directory (file names matched by seqfile_re), then print the
    names of the models for which it reported failure.

    :return: None
    """
    matches = (seqfile_re.match(entry) for entry in os.listdir())
    # Collect all model names before processing any of them.
    modelnames = [found.group(1) for found in matches if found]

    problem_models = [name for name in modelnames if not process_one(name)]
    print("problem models:")
    print("\n".join(problem_models))


def run_from_args():
    """
    Command line entry point: ``<working_dir> <model-name>``.

    Changes into the working directory, processes the model and exits the
    interpreter.  Exit status is 0 when the model checked out, 3 when
    process_one() reported failure, and 1 when too few arguments were given.

    :return: does not return; always ends with sys.exit()
    """
    argc = len(sys.argv)
    if argc < 3:
        print(f"usage: {sys.argv[0]} <working_dir> <model-name>")
        print(f"Only {argc} arguments")
        sys.exit(1)
    workdir, model = sys.argv[1], sys.argv[2]
    os.chdir(workdir)
    sys.exit(0 if process_one(model) else 3)


# Script entry point: check the single model named on the command line.
if __name__ == '__main__':
    run_from_args()

# See PyCharm help at https://www.jetbrains.com/help/pycharm/
