#!/usr/bin/python3
import argparse
import common
import json
import logging
import os
import pwd
import re
import subprocess
import sys
import glob

from stat import *
from xml.dom.minidom import parse, parseString, Node

def command_requires_pkg(command):
    """Tell whether ``command`` operates on one named package.

    Returns True for the per-package commands (those that need the ``pkg``
    positional argument) and False for every other value, such as the
    ``-all`` variants and ``version``.
    """
    per_package_commands = (
        'default',
        'list',
        'set',
        'set-force',
        'unset',
        'unset-force',
        'validate',
    )
    return command in per_package_commands

def validate_environment():
    """Check the run-time preconditions: root privileges, then common.ls_ok()."""
    running_as_root = os.getuid() == 0
    if not running_as_root:
        common.fatal_error("this program must be run as root")
    # NOTE(review): presumably verifies the LiteSpeed installation -- confirm
    # in the common module.
    common.ls_ok()
    
def init_pgm():
    """Initialize logging, parse the command line and sanity-check it.

    The ``version`` command is handled here: the version is logged (and
    printed when --quiet is given) and the process exits with status 0.

    Returns:
        The argparse namespace describing the requested command.
    """
    common.init_logging()
    commands = ['default', 'list', 'list-all', 'set', 'set-force', 'unset', 'unset-force', 'validate', 'validate-all', 'version']
    parser = argparse.ArgumentParser(
        prog="lspkgctl",
        description='LiteSpeed Packages Control Program',
    )
    parser.add_argument(
        "command", type=str, choices=commands,
        help="The packages command")
    parser.add_argument(
        "pkg", type=str, nargs='?', default=None,
        help="package name")
    parser.add_argument(
        "--cpu", type=str,
        help="limit CPU access.  Specify a percentage of a complete core, 100 is 1 complete core.  Default is no limit (-1).")
    parser.add_argument(
        "--io", type=str,
        help="limit I/O access.  Specify in bytes/sec.  Default is no limit (-1).")
    parser.add_argument(
        "--iops", type=str,
        help="limit I/O access.  Specify in IOs per second.  Default is no limit (-1).")
    parser.add_argument(
        '-l', '--log', type=int,
        help='set logging level, 10=Debug, 20=Info, 30=Warning, 40=Error.  Default is Info')
    parser.add_argument(
        "--mem", type=str,
        help="limit virtual memory access.  Specify in bytes.  Default is no limit (-1).")
    parser.add_argument(
        '-q', '--quiet', action='store_true',
        help='turns off all logging and only outputs what is requested.')
    parser.add_argument(
        "--tasks", type=str,
        help="limit number of tasks.  Specify a maximum count of tasks that can be started.")
    opts = parser.parse_args()

    # An explicit --log level wins over --quiet; without either, log at INFO.
    if opts.log is not None or not opts.quiet:
        level = opts.log if opts.log is not None else logging.INFO
        logging.getLogger().setLevel(level)
        logging.debug("Entering lspkgctl")

    validate_environment()
    command = opts.command
    if command_requires_pkg(command) and opts.pkg is None:
        common.fatal_error("You must specify a package for the %s command" % command)
    if command == 'version':
        logging.info("Version: %s" % common.VERSION)
        if opts.quiet:
            print(common.VERSION)
        sys.exit(0)
    # set / set-force are meaningless without at least one limit to store.
    limits = (opts.cpu, opts.io, opts.iops, opts.mem, opts.tasks)
    if command.startswith('set') and all(limit is None for limit in limits):
        common.fatal_error("You specified command: %s and no qualifier to set with it" % command)
    return opts

def run_program(args, fail_reason = ""):
    """Run an external command and return its captured standard output.

    Args:
        args: Command line as a list; args[0] is the program to run.
        fail_reason: When empty (the default) the command must exit with
            status 0.  When non-empty the command is expected to fail, and
            this text is included in the error reported if it succeeds.

    Returns:
        The command's stdout as text.
    """
    logging.debug('run: ' + str(args))
    completed = subprocess.run(args, capture_output=True, text=True)
    succeeded = completed.returncode == 0
    failure_expected = fail_reason != ""
    if failure_expected and succeeded:
        common.fatal_error('Expected: ' + args[0] + ' to fail: ' + fail_reason)
    if not failure_expected and not succeeded:
        common.fatal_error('Error in running: ' + args[0] + ' errors: ' + completed.stdout + ' ' + completed.stderr)
    return completed.stdout

def read_pkg(pkg):
    """Load the stored JSON settings of a package.

    When the package file cannot be opened or parsed, a warning is logged
    and the package is rebuilt with initialize_pkg(pkg, None); the result of
    that call (which may be None) is returned instead.

    Args:
        pkg: Package name; common.pkg_to_filename() maps it to its file.

    Returns:
        The package dict with its 'uids' list sorted in place, or the value
        returned by initialize_pkg() when the file had to be reinitialized.
    """
    filename = common.pkg_to_filename(pkg)
    try:
        f = open(filename, 'r', encoding="utf-8")
    except Exception as err:
        logging.warning('Error opening %s: %s - reinitializing' % (filename, err))
        return initialize_pkg(pkg, None)
    # Only remember a parse failure here and act on it after the with-block,
    # so the read handle is closed before initialize_pkg() rewrites the file.
    parse_error = None
    with f:
        try:
            pkg_json = json.load(f)
        except Exception as err:
            parse_error = err
    if parse_error is not None:
        logging.warning('Error reading %s: %s - reinitializing' % (filename, parse_error))
        return initialize_pkg(pkg, None)
    # Valid JSON that is not an object (a list, a number, ...) has no uids.
    if not isinstance(pkg_json, dict) or 'uids' not in pkg_json:
        common.fatal_error("Missing uids in package: %s" % pkg)
    pkg_json['uids'].sort()
    return pkg_json

def list_pkg(pkg):
    """Return the stored settings of ``pkg`` exactly as read_pkg() provides them."""
    return read_pkg(pkg)

def command_list(args, single_list=False):
    """Print package settings to stdout as JSON.

    Args:
        args: Parsed command line; ``args.command`` and ``args.pkg`` are used.
        single_list: Force listing only ``args.pkg`` even when the command is
            not 'list'.

    Returns:
        Always 0.
    """
    pkg_obj = {}
    if single_list or args.command == 'list':
        pkg_obj[args.pkg] = list_pkg(args.pkg)
    else:
        for path in glob.glob(common.pkg_to_filename('*')):
            fileonly = os.path.basename(path)
            # Cut at the LAST '.conf' so a package whose own name contains
            # '.conf' is not truncated.
            # NOTE(review): assumes the file name is the package name followed
            # by a suffix starting with '.conf' -- confirm against
            # common.pkg_to_filename().
            pkg = fileonly.rsplit('.conf', 1)[0]
            logging.debug('From file: %s, using pkg %s' % (path, pkg))
            pkg_obj[pkg] = list_pkg(pkg)
    print(json.dumps({'packages': pkg_obj}, indent=4))
    return 0

def plesk_home_dict():
    """Map home directory -> passwd entry for accounts with uid >= common.get_min_uid()."""
    by_home = {}
    for entry in pwd.getpwall():
        if entry.pw_uid < common.get_min_uid():
            continue
        logging.debug('home_dict key: ' + entry.pw_dir)
        by_home[entry.pw_dir] = entry
    return by_home

def read_pleskplans():
    """Build a mapping of Plesk plan name -> list of uids subscribed to it.

    Runs ``plesk db`` with a query returning (login, home, plan name) rows and
    parses the '|'-separated table it prints.  Each row's home directory is
    looked up in plesk_home_dict() to find the system uid.

    Returns:
        Dict of plan name -> list of uids.
    """
    home_dict = plesk_home_dict()
    plans = {}
    query = ("SELECT sys_users.login, sys_users.home, "
             "Templates.name from sys_users join "
             "PlansSubscriptions on "
             "PlansSubscriptions.subscription_id = sys_users.id "
             "join Templates on "
             "PlansSubscriptions.plan_id = Templates.id")
    plan_lines = run_program(['plesk', 'db', query])
    # Compiled once for all lines: one table cell, i.e. "| value   |".
    cell_pat = re.compile(r'\|\s(.*?)\s{1,}\|')
    for line in plan_lines.splitlines():
        if not line.strip():
            # A blank line carries no row; do not treat it as malformed.
            continue
        if line.startswith(('+-', '| login')):
            # Table border or header row.
            continue
        cells = []
        pos = 0
        while (match := cell_pat.search(line, pos)) is not None:
            logging.debug('Appending: %s in %s' % (match[1], line))
            # Resume right after this cell's opening '|': its closing '|' is
            # also the opening '|' of the next cell.
            pos = match.start() + 1
            cells.append(match[1])
        if len(cells) != 3:
            for cell in cells:
                logging.debug(cell)
            common.fatal_error("Expected 3 matches in %s and got %d" % (line, len(cells)))
        home, plan = cells[1], cells[2]
        if home not in home_dict:
            common.fatal_error("%s not in user dictionary from %s" % (home, line))
        user = home_dict[home]
        logging.debug("   add to %s: %s (%d)" % (plan, user.pw_name, user.pw_uid))
        plans.setdefault(plan, []).append(user.pw_uid)
    return plans

def read_userplans():
    """Build a mapping of package (plan) name -> list of uids assigned to it.

    When common.get_plesk() is true the work is delegated to
    read_pleskplans().  Otherwise every regular file in /var/cpanel/packages
    becomes a package, and /etc/userplans ("user: package" lines, '#' starts
    a comment) assigns users to them.  Users whose package has no file in
    /var/cpanel/packages are ignored.

    Returns:
        Dict of package name -> list of uids (possibly empty).
    """
    if common.get_plesk():
        return read_pleskplans()
    plans = {}
    for planfile in glob.glob('/var/cpanel/packages/*'):
        if os.path.isfile(planfile):
            plan = os.path.basename(planfile)
            plans[plan] = []
            logging.info("Add package: %s" % plan)
    try:
        f = open('/etc/userplans', 'r', encoding="utf-8")
    except Exception as err:
        common.fatal_error('Error opening /etc/userplans: %s' % err)
    # Read everything up front so the handle is closed before any fatal
    # error can be raised while processing the lines.
    with f:
        lines = f.readlines()
    for line in lines:
        if line.startswith('#'):
            continue
        pieces = line.split(': ')
        if len(pieces) != 2:
            continue
        user = pieces[0].strip()
        pkg = pieces[1].strip(' \n\t')
        logging.debug('   pkg: %s, user: %s' % (pkg, user))
        if pkg not in plans:
            continue
        try:
            pwd_var = pwd.getpwnam(user)
        except Exception as err:
            common.fatal_error("passwd entry error: %s for user %s" % (err, user))
        plans[pkg].append(pwd_var.pw_uid)
        logging.debug("   add to %s: %s (%d)" % (pkg, user, pwd_var.pw_uid))
    return plans
    
def option_default(opt, pkg_json):
    """Tell whether ``opt`` is at its default (unlimited) value in ``pkg_json``.

    An option is at its default when it is absent or equal to the integer -1.
    A string such as '-1' does not compare equal to that integer and therefore
    counts as an explicit, non-default value.
    """
    return opt not in pkg_json or pkg_json[opt] == -1

def pkg_json_defaults(pkg_json):
    """Return True when every option from common.get_options() is at its default in ``pkg_json``."""
    return all(option_default(opt, pkg_json) for opt in common.get_options())

def set_users_to_pkg(users, pkg_json):
    """Apply a package's limits to users by running the lscgctl program.

    Every option from common.get_options() is passed explicitly: the package's
    value when it has one, '-1' otherwise.  The lscgctl command is 'unlimited'
    when all options are at their defaults and 'set' otherwise.

    Args:
        users: Uids to apply the limits to; nothing is run when empty.
        pkg_json: Package settings dict.
    """
    if not users:
        logging.debug("Do not set users and none to set")
        return
    args = [common.get_bin_file('lscgctl')]
    defaults = True
    for opt in common.get_options():
        args.append('--%s' % opt)
        if option_default(opt, pkg_json):
            args.append('-1')
        else:
            # Values loaded from JSON are not guaranteed to be strings, and
            # subprocess rejects non-string arguments.
            args.append(str(pkg_json[opt]))
            defaults = False
    args.append('unlimited' if defaults else 'set')
    args.extend(str(user) for user in users)
    run_program(args)

def write_pkg(pkg, pkg_json):
    """Store a package's settings as indented JSON in its package file.

    Args:
        pkg: Package name; common.pkg_to_filename() maps it to its file.
        pkg_json: Settings dict to serialize.
    """
    filename = common.pkg_to_filename(pkg)
    logging.debug('Write pkg: %s as %s' % (pkg, filename))
    # Serialize before opening: mode 'w' truncates the file, so a failure to
    # serialize must not be able to destroy the existing contents.
    try:
        text = json.dumps(pkg_json, indent=4)
    except Exception as err:
        common.fatal_error('Error writing %s: %s' % (filename, err))
    try:
        f = open(filename, 'w', encoding="utf-8")
    except Exception as err:
        common.fatal_error('Error opening %s for write: %s' % (filename, err))
    try:
        # The with-block closes the handle even when the write fails.
        with f:
            f.write(text)
    except Exception as err:
        common.fatal_error('Error writing %s: %s' % (filename, err))
    
def revalidate_pkg(pkg, users, pkg_json):
    """Synchronize a stored package's uid list with its current users.

    Uids in ``users`` that are missing from ``pkg_json['uids']`` are added and
    the package limits are then applied through set_users_to_pkg(); stored
    uids no longer present in ``users`` are removed.  The package file is
    rewritten only when something changed.

    Args:
        pkg: Package name.
        users: Uids currently assigned to the package.
        pkg_json: Stored package settings; its 'uids' list is updated in place.

    Returns:
        ``pkg_json`` (the same object that was passed in).
    """
    rewrite_json = False
    for user in users:
        if user not in pkg_json['uids']:
            logging.info("Add user: %s to package: %s" % (user, pkg))
            pkg_json['uids'].append(user)
            rewrite_json = True
    if rewrite_json:
        set_users_to_pkg(users, pkg_json)
    # Walk a snapshot: removing from the list being iterated would skip the
    # element that follows each removed one.
    for uid in list(pkg_json['uids']):
        if uid not in users:
            logging.info("Remove user: %s from package: %s" % (uid, pkg))
            pkg_json['uids'].remove(uid)
            rewrite_json = True
    if rewrite_json:
        write_pkg(pkg, pkg_json)
    return pkg_json

def initialize_pkg(pkg, users):
    """Create the package file for ``pkg`` holding only its list of uids.

    Args:
        pkg: Package name.
        users: Uids assigned to the package, or None to look them up with
            read_userplans().

    Returns:
        The new package dict ({'uids': [...]}), or None when the package has
        no users, in which case no file is written.
    """
    if users is None:
        plans = read_userplans()
        if pkg not in plans:
            # Report an unknown package cleanly instead of a raw KeyError.
            common.fatal_error("You specified package %s and it doesn't exist" % pkg)
        users = plans[pkg]
    if not users:
        return None
    pkg_json = {'uids': list(users)}
    write_pkg(pkg, pkg_json)
    return pkg_json

def validate_pkg(pkg, users):
    """Bring the stored state of ``pkg`` in line with its current users.

    An existing package file is read and re-synchronized with
    revalidate_pkg(); a missing one is created with initialize_pkg().

    Args:
        pkg: Package name.
        users: Uids assigned to the package, or None to look them up with
            read_userplans().

    Returns:
        The package dict, or None when the package has no users and no
        usable package file.
    """
    if users is None:
        plans = read_userplans()
        if pkg not in plans:
            # Report an unknown package cleanly instead of a raw KeyError.
            common.fatal_error("You specified package %s and it doesn't exist" % pkg)
        users = plans[pkg]
    if not os.path.isfile(common.pkg_to_filename(pkg)):
        return initialize_pkg(pkg, users)
    pkg_json = read_pkg(pkg)
    if pkg_json is None:
        # read_pkg() could not use the file and reinitialized a package that
        # has no users; there is nothing to revalidate.
        return None
    return revalidate_pkg(pkg, users, pkg_json)
    
def command_validate(args):
    """Validate one package, or every package for the '-all' commands.

    Args:
        args: Parsed command line; ``args.command`` and ``args.pkg`` are used.

    Returns:
        Always 0.
    """
    plans = read_userplans()
    if args.command in ('validate-all', 'list-all'):
        for plan in plans:
            plans[plan] = validate_pkg(plan, plans[plan])
        return 0
    if args.pkg not in plans:
        common.fatal_error("You specified package %s and it doesn't exist" % args.pkg)
    plans[args.pkg] = validate_pkg(args.pkg, plans[args.pkg])
    return 0

def set_json(args, pkg_json):
    """Copy the limit options from the command line into ``pkg_json``.

    For every option from common.get_options(): a value given on the command
    line is stored (when it differs from the stored one), and an option that
    was not given is removed from ``pkg_json`` if present.

    Args:
        args: Parsed command line namespace.
        pkg_json: Package settings dict, modified in place.

    Returns:
        True when ``pkg_json`` was modified, False otherwise.
    """
    logging.debug('In set_json')
    requested = vars(args)
    changed = False
    for opt in common.get_options():
        value = requested[opt]
        if value is not None:
            logging.debug('Arg %s = %s' % (opt, value))
            if opt in pkg_json and pkg_json[opt] == value:
                logging.debug('   not changed')
            else:
                pkg_json[opt] = value
                changed = True
        elif opt in pkg_json:
            logging.debug('   delete it, if there')
            del pkg_json[opt]
            logging.debug('   deleted')
            changed = True
        else:
            logging.debug('Arg %s not specified' % (opt))
    return changed

def _normalized_limit(value):
    """Return a limit value in comparable form: '' for None, else lower-cased text."""
    return '' if value is None else str(value).lower()


def get_defaults_users(pkg_json):
    """List the package's users whose current limits match the package's.

    Runs ``lscgctl list-all`` and compares, for each uid in
    ``pkg_json['uids']`` that appears in its JSON output, every option from
    common.get_options() with the package's value.  The comparison ignores
    case and treats a missing value and an empty string as equal.

    Args:
        pkg_json: Package settings dict containing 'uids'.

    Returns:
        The uids (taken from ``pkg_json['uids']``) whose settings all match.
    """
    # NOTE(review): 'q' is passed without a leading dash -- confirm this is
    # what lscgctl expects.
    pgm_args = [common.get_bin_file('lscgctl'), 'list-all', 'q']
    out = run_program(pgm_args)
    try:
        users_json = json.loads(out)
    except Exception as err:
        common.fatal_error('Error parsing user details: %s' % err)
    defaults_users = []
    logging.debug('get_defaults_users for package')
    for user in pkg_json['uids']:
        key = str(user)
        if key not in users_json:
            continue
        user_opts = users_json[key]
        user_ok = True
        for opt in common.get_options():
            # Normalizing both sides avoids calling .lower() on a missing
            # (None) or non-string value.
            if _normalized_limit(user_opts.get(opt)) != _normalized_limit(pkg_json.get(opt)):
                logging.debug('For user: %s opt %s different, %s != %s' % (key, opt, user_opts.get(opt), pkg_json.get(opt)))
                user_ok = False
                break
        if not user_ok:
            continue
        logging.debug('Add user %s' % key)
        defaults_users.append(user)
    return defaults_users

def command_set(args):
    """Handle the set, set-force, unset, unset-force and default commands.

    For commands whose name contains 'force', and for 'default', the package
    limits are applied to every user of the package; otherwise only to the
    users returned by get_defaults_users(), evaluated before the package is
    changed.  For commands whose name contains 'set' the package file is
    updated first, and nothing is applied when that update changes nothing.

    Args:
        args: Parsed command line namespace.

    Returns:
        Always 0.
    """
    pkg_json = read_pkg(args.pkg)
    if pkg_json is None:
        # read_pkg() yields None for a package without users and without a
        # usable package file; indexing it below would raise a TypeError.
        common.fatal_error("Package %s has no users, nothing to set" % args.pkg)
    if 'force' in args.command or 'default' == args.command:
        uids = pkg_json['uids']
    else:
        uids = get_defaults_users(pkg_json)
    if 'set' in args.command:
        if not set_json(args, pkg_json):
            return 0
        write_pkg(args.pkg, pkg_json)
    set_users_to_pkg(uids, pkg_json)
    return 0

def do_pgm(args):
    """Dispatch the parsed command to its handler.

    Args:
        args: Parsed command line namespace.

    Returns:
        The status returned by the command handler.
    """
    command = args.command
    logging.debug("Entering lspkgctl, command: %s" % command)
    if 'validate' in command or 'list' in command:
        # The list commands validate first, then print.
        ret = command_validate(args)
        if 'list' in command:
            ret = command_list(args)
    elif 'set' in command or 'default' in command:
        # 'set' is also a substring of 'unset' and 'unset-force', so this
        # branch serves the set*, unset* and default commands alike.
        ret = command_set(args)
    else:
        common.fatal_error('Unexpected command: %s' % command)
    logging.debug("Exiting lspkgctl")
    return ret

def main():
    """Program entry point: parse the command line and run the command.

    Returns:
        The status returned by do_pgm().
    """
    args = init_pgm()
    return do_pgm(args)

if __name__ == "__main__":
    # Use the command's status as the process exit status.
    sys.exit(main())
