#!/usr/bin/python3

import argparse
import os
import time

import psutil
from pcaspy import Driver, SimpleServer


def base_database():
    """Build the static part of the EPICS PV database.

    Returns a brand-new dict mapping PV names (IOC prefix not included) to
    pcaspy field specifications. Interface-specific network PVs are not part
    of it; they are added on top of this dict.
    """
    def count_spec(unit):
        # Integer-valued PVs that count things.
        return dict(prec=0, unit=unit, type='int', scan=1)

    def mbit_spec():
        # Float throughput PVs in Mb/s; a fresh dict per PV.
        return dict(prec=3, unit='Mb/s', type='float', scan=1)

    return dict(
        CPU_LOAD_PERCENT=dict(prec=3, unit='percent', scan=1,
                              hihi=90, high=80, low=-1, lolo=-1),
        CPU_COUNT=dict(prec=1, unit='cores', scan=1, type='int'),
        LOAD_AVERAGE=dict(scan=1, type='string'),
        MEMORY_AVAIL_MB=dict(type='int', prec=0, unit='MB', scan=1),
        MEMORY_AVAIL_PERCENT=dict(prec=3, unit='percent', scan=1,
                                  hihi=101, high=101, low=20, lolo=10),
        INET_CONNECTIONS=count_spec('counts'),
        PROCESSES=count_spec('counts'),
        NET_TX_TOTAL_MBIT=mbit_spec(),
        NET_RX_TOTAL_MBIT=mbit_spec(),
    )


def network_interface_pvname_sent(name):
    """Return the PV name (without prefix) of *name*'s transmit-rate PV."""
    upper_name = name.upper()
    return 'NET_TX_' + upper_name + '_MBIT'


def network_interface_pvname_recv(name):
    """Return the PV name (without prefix) of *name*'s receive-rate PV."""
    upper_name = name.upper()
    return 'NET_RX_' + upper_name + '_MBIT'


def add_network_interfaces(network_counters, db):
    """Add one transmit-rate and one receive-rate PV per monitored interface.

    :param network_counters: object whose ``interfaces()`` yields the
        interface names to publish.
    :param db: PV database dict; it is modified in place.
    :returns: the same *db* object, for convenience.
    """
    for iface in network_counters.interfaces():
        # TX first, then RX, each with its own fresh spec dict.
        for pvname in (network_interface_pvname_sent(iface),
                       network_interface_pvname_recv(iface)):
            db[pvname] = dict(prec=3, unit='Mb/s', type='float', scan=1)
    return db


class NetworkMonitor(object):
    """Derive per-interface and total network transfer rates from psutil.

    Each successful :meth:`update_all` takes a snapshot of
    ``psutil.net_io_counters(pernic=True, nowrap=True)``. It turns the
    difference to the previous snapshot into rates in bytes per second,
    floor-divided by the elapsed wall-clock time. Interface names are
    renamed through ``name_mapping`` before use.

    When ``poll_iface`` is set, :meth:`get_counters` only hands out a
    snapshot once that interface's received-byte counter has changed, or
    after ``_POLL_TIMEOUT_S`` seconds. Callers must therefore call
    :meth:`update_all` repeatedly until it returns True.
    """

    # Seconds to wait for the polled interface's RX counter to change before
    # accepting a snapshot anyway.
    _POLL_TIMEOUT_S = 10

    def __init__(self, period=5, name_mapping=None, poll_iface=None):
        """
        :param period: nominal update period in seconds. It is only stored;
            rates use the measured time between snapshots.
        :param name_mapping: optional dict of raw interface name -> reported
            name. Defaults to no renaming.
        :param poll_iface: optional interface to poll (raw or renamed name).
            A falsy value disables polling.
        """
        # Raw byte counters per interface from the previous snapshot.
        self.__counter_sent = {}
        self.__counter_recv = {}
        # Latest rates (bytes/s) per interface and summed over interfaces.
        self.__rate_sent = {}
        self.__rate_recv = {}
        self.__total_rate_sent = 0
        self.__total_rate_recv = 0
        self.__period = period
        # Wall-clock time of the previous snapshot; 0 means "none yet".
        self.__last_update_time = 0
        self.__poll_iface = poll_iface
        self.__is_polling = False  # True while waiting for poll_iface's RX counter to move
        self.__start_rx = 0  # RX counter of poll_iface when the current poll started
        self.__start_time = 0  # wall-clock time when the current poll started
        # Avoid a shared mutable default; an explicitly passed dict is kept as-is.
        self.__name_mapping = name_mapping if name_mapping is not None else {}
        # Names of the interfaces present at construction time.
        self.__interfaces = list(self.__net_io_counters())

    def __net_io_counters(self):
        """Return psutil per-NIC counters keyed by the (renamed) interface name."""
        return {self.__name_mapping.get(k, k): v for (k, v) in psutil.net_io_counters(pernic=True, nowrap=True).items()}

    def interfaces(self):
        """Return the interface names captured at construction time."""
        return self.__interfaces

    @staticmethod
    def __lookup_rate(rates, iface):
        # Rates are keyed by the interface name exactly as produced by
        # __net_io_counters (which may contain upper case, e.g. "Ethernet").
        # Fall back to the lower-cased form for callers relying on that.
        if iface in rates:
            return rates[iface]
        return rates.get(iface.lower(), 0)

    def get_rate_sent(self, iface: str):
        """Return the latest send rate in bytes/s.

        ``''`` selects the total over all interfaces; unknown interfaces
        report 0.
        """
        if iface == '':
            return self.__total_rate_sent
        return self.__lookup_rate(self.__rate_sent, iface)

    def get_rate_recv(self, iface: str):
        """Return the latest receive rate in bytes/s.

        ``''`` selects the total over all interfaces; unknown interfaces
        report 0.
        """
        if iface == '':
            return self.__total_rate_recv
        return self.__lookup_rate(self.__rate_recv, iface)

    def __resolve_poll_iface(self, counters):
        # counters is keyed by renamed names, but the user may have given
        # either the raw or the renamed interface name.
        iface = self.__poll_iface
        for candidate in (iface, self.__name_mapping.get(iface)):
            if candidate is not None and candidate in counters:
                return candidate
        return None

    def poll_for_change(self):
        """Wait (across calls) until the polled interface's RX counter changes.

        :returns: ``None`` while still waiting. Otherwise returns the
            counter snapshot, once the RX counter has moved or
            ``_POLL_TIMEOUT_S`` has elapsed. If the polled interface is
            absent (unplugged, misspelled), polling is skipped and the
            snapshot is returned immediately instead of raising KeyError.
        """
        counters = self.__net_io_counters()
        iface = self.__resolve_poll_iface(counters)
        if iface is None:
            self.__is_polling = False
            return counters
        rx = counters[iface].bytes_recv
        if self.__is_polling:
            if rx != self.__start_rx or time.time() - self.__start_time > self._POLL_TIMEOUT_S:
                self.__is_polling = False
                return counters
            return None
        # Start a new poll: remember where the counter was and when.
        self.__start_rx = rx
        self.__start_time = time.time()
        self.__is_polling = True
        return None

    def get_counters(self):
        """Return a counter snapshot, or ``None`` while polling is still waiting."""
        if self.__poll_iface:
            return self.poll_for_change()
        else:
            return self.__net_io_counters()

    def update_all(self):
        """Take a counter snapshot and recompute all rates.

        :returns: False (nothing updated) while polling is still waiting for
            the polled interface to change; True otherwise.
        """
        data = self.get_counters()
        if data is None:  # Don't update if we're polling.
            return False
        previous_time = self.__last_update_time
        now = time.time()
        self.__last_update_time = now
        period = now - previous_time
        # Rates need a previous snapshot and a clock that moved forward. This
        # avoids ZeroDivisionError and negative rates after a clock step.
        can_rate = previous_time != 0 and period > 0

        delta_total_sent = 0
        delta_total_recv = 0
        for iface, counters in data.items():
            sent_counter = counters.bytes_sent
            recv_counter = counters.bytes_recv

            rate_sent = 0
            rate_recv = 0
            # Key on "seen in the previous snapshot", not on "previous counter
            # non-zero", so a counter that started at 0 still gets a rate.
            if can_rate and iface in self.__counter_sent:
                delta_sent = sent_counter - self.__counter_sent[iface]
                delta_recv = recv_counter - self.__counter_recv[iface]
                rate_sent = delta_sent // period
                rate_recv = delta_recv // period
                # Totals only include interfaces present in both snapshots,
                # so appearing/vanishing interfaces don't cause bogus spikes.
                delta_total_sent += delta_sent
                delta_total_recv += delta_recv

            self.__counter_sent[iface] = sent_counter
            self.__counter_recv[iface] = recv_counter
            self.__rate_sent[iface] = rate_sent
            self.__rate_recv[iface] = rate_recv

        # Interfaces missing from this snapshot report 0 instead of a stale
        # rate, and start from scratch if they show up again.
        for iface in [name for name in self.__counter_sent if name not in data]:
            del self.__counter_sent[iface]
            del self.__counter_recv[iface]
            self.__rate_sent[iface] = 0
            self.__rate_recv[iface] = 0

        if can_rate:
            self.__total_rate_sent = delta_total_sent // period
            self.__total_rate_recv = delta_total_recv // period
        return True

def read_recv_mbit(counters, iface):
    """Return a zero-argument callable that reports *iface*'s receive rate.

    The callable converts ``counters.get_rate_recv(iface)`` (bytes/s) to
    Mbit/s, using 8 bits per byte and 1024 * 1024 bits per Mbit. It queries
    *counters* on every call, so it always reflects the latest rate.
    """
    bits_per_mbit = 1024.0 * 1024.0

    def _reader():
        bytes_per_second = counters.get_rate_recv(iface)
        return 8 * bytes_per_second / bits_per_mbit

    return _reader


def read_sent_mbit(counters, iface):
    """Return a zero-argument callable that reports *iface*'s send rate.

    The callable converts ``counters.get_rate_sent(iface)`` (bytes/s) to
    Mbit/s, using 8 bits per byte and 1024 * 1024 bits per Mbit. It queries
    *counters* on every call, so it always reflects the latest rate.
    """
    bits_per_mbit = 1024.0 * 1024.0

    def _reader():
        bytes_per_second = counters.get_rate_sent(iface)
        return 8 * bytes_per_second / bits_per_mbit

    return _reader


class MonitorDriver(Driver):
    """pcaspy driver that answers PV reads with live system statistics.

    Fixed PVs are computed directly in :meth:`read`. Network-rate PVs are
    served through callables registered by :meth:`setNetworkInterfaces`.
    Any other PV falls back to the value stored in the driver.
    """

    def __init__(self):
        super(MonitorDriver, self).__init__()
        self.__network_counters = None
        # PV name -> zero-argument callable returning the PV's current value.
        self.__lookup = {}

    def setNetworkInterfaces(self, network_counters):
        """Register rate readers for the totals and for each interface.

        :param network_counters: object providing ``interfaces()`` and the
            ``get_rate_sent``/``get_rate_recv`` accessors the readers call
            (e.g. a ``NetworkMonitor``). An interface name of ``''`` selects
            the totals.
        """
        self.__network_counters = network_counters
        self.__lookup = {
            'NET_RX_TOTAL_MBIT': read_recv_mbit(self.__network_counters, ''),
            'NET_TX_TOTAL_MBIT': read_sent_mbit(self.__network_counters, ''),
        }
        for name in network_counters.interfaces():
            self.__lookup[network_interface_pvname_recv(name)] = read_recv_mbit(
                self.__network_counters, name)
            self.__lookup[network_interface_pvname_sent(name)] = read_sent_mbit(
                self.__network_counters, name)

    def read(self, reason):
        """Return the current value of PV *reason* (name without prefix)."""
        if reason == 'CPU_LOAD_PERCENT':
            value = psutil.cpu_percent()
        elif reason == 'CPU_COUNT':
            value = psutil.cpu_count(logical=True)
        elif reason == 'LOAD_AVERAGE':
            try:
                value = "{0}".format(os.getloadavg())
            except (OSError, AttributeError):
                # OSError: load average unobtainable. AttributeError:
                # os.getloadavg is Unix-only and absent on e.g. Windows.
                value = " - "
        elif reason == 'MEMORY_AVAIL_PERCENT':
            value = 100.0 - psutil.virtual_memory().percent
        elif reason == 'MEMORY_AVAIL_MB':
            value = int(psutil.virtual_memory().available / (1024 * 1024))
        elif reason == 'INET_CONNECTIONS':
            try:
                value = len(psutil.net_connections())
            except psutil.AccessDenied:
                # Some platforms (e.g. macOS) require root for this call. Keep
                # serving the stored value instead of failing the read.
                value = self.getParam(reason)
        elif reason == 'PROCESSES':
            value = len(psutil.pids())
        elif reason in self.__lookup:
            value = self.__lookup[reason]()
        else:
            value = self.getParam(reason)
        return value


def parse_name_mapping(mappings):
    """Parse ``iface=new_name`` strings into a rename dict.

    Only the first ``=`` separates the two parts. The new name is
    lower-cased; the interface name is kept as given.

    :param mappings: iterable of ``iface=new_name`` strings.
    :returns: dict mapping interface name -> lower-cased new name.
    :raises ValueError: if an entry lacks ``=`` or either side is empty.
    """
    results = {}
    for entry in mappings:
        iface, sep, new_name = entry.partition('=')
        if not sep or not iface or not new_name:
            raise ValueError(
                "invalid interface mapping {0!r}: expected iface=new_name".format(entry))
        results[iface] = new_name.lower()
    return results


def main():
    """Run the system-monitor IOC.

    Parses the command line and prints every PV name. Unless ``--chans``
    is given, it then serves the PVs forever, refreshing network rates
    every ``--net-period`` seconds.
    """
    parser = argparse.ArgumentParser(
        description="A simple system monitor that reflects its stats into EPICS",
    )
    parser.add_argument(
        '-c', '--chans', action='store_true',
        help="print EPICS channels and exit",
    )
    parser.add_argument(
        '--net-period', default=5, type=float,
        help="set the polling period for the network interfaces"
    )
    parser.add_argument('--rename-interfaces', nargs='+',
                        default=[],
                        help='iface=new_name pairs to allow mapping interface to high level functional names')
    parser.add_argument(
        'prefix',
        help="Prefix of epics variables",
    )
    parser.add_argument(
        '--poll-iface', default="", type=str,
        help="poll this interface to get more precise transfer rates at the cost of some cpu usage"
    )
    args = parser.parse_args()

    net_period = args.net_period
    # The scheduling loop below advances check_network_at by net_period until
    # it passes `now`. A zero, negative or NaN period would spin forever.
    if not net_period > 0:
        parser.error("--net-period must be a positive number of seconds")

    name_mapping = parse_name_mapping(args.rename_interfaces)
    poll_iface = args.poll_iface

    network_counters = NetworkMonitor(net_period, name_mapping, poll_iface)

    pvdb = add_network_interfaces(network_counters, base_database())

    for entry in pvdb:
        print("{0}{1}".format(args.prefix, entry))
    if args.chans:
        parser.exit()

    server = SimpleServer()
    server.createPV(args.prefix, pvdb)
    driver = MonitorDriver()
    driver.setNetworkInterfaces(network_counters)

    now = time.time()
    check_network_at = now + net_period
    network_counters.update_all()

    # While a network update is due but not yet taken (e.g. polling is
    # waiting for the interface to change), let the server return quickly so
    # update_all is retried often; otherwise use the slower timeout.
    fast_timeout = 0.001
    slow_timeout = 0.1
    server_timeout = slow_timeout

    while True:
        now = time.time()
        if now >= check_network_at:
            server_timeout = fast_timeout
            updated = network_counters.update_all()
            if updated:
                # Skip any missed slots so updates stay on the period grid.
                while check_network_at <= now:
                    check_network_at += net_period
                server_timeout = slow_timeout
        server.process(server_timeout)


# Start the IOC only when run as a script, not when imported.
if __name__ == "__main__":
    main()
