#!/usr/bin/env python3

import os
import sys
import argparse
import subprocess
from concurrent import futures

# Path template of the files listing the remotes of a set. The %s placeholder
# is filled in with the set name selected by the --set flag.
SET_PATH = "/storage/bolidozor/support/bzremote_sets/%s"

# Text passed to argparse as the parser description (printed in --help).
DESCRIPTION = """
bzremote runs commands in bulk on a selected group of Bolidozor stations.

A group of remote hosts to which the command will be applied is read from a file.
By default, the path to this file is

    /storage/bolidozor/support/bzremote_sets/default

where the file name 'default' can be changed by the --set flag. Each line in
the file represents one hostname.

The flags are followed by a template for the commands to run. A command is run
per every remote, each time the command is obtained by replacing the special
character % in the command template with the remote's hostname.

The commands are run localy. Embed ssh or scp into the template to connect
to the remote host.
"""

# Text passed to argparse as the parser epilog (printed at the end of --help).
EXAMPLES = """
Examples
--------

To distribute the admin login key to all stations, run:

    $ ./bzremote --threads 1 ssh-copy-id -i /etc/ssh/ssh_rmds_admin_key %

.
"""

def _load_remotes(set_path, command_template):
    """Read the set file and build one command per remote.

    Returns a list of ``(hostname, command)`` pairs, where ``command`` is the
    template with every ``%`` replaced by the hostname. Lines that are blank
    after stripping whitespace are skipped.
    """
    # The context manager closes the set file even if reading fails.
    with open(set_path, 'r') as set_file:
        hostnames = [line.strip() for line in set_file]
    return [
        (hostname, [part.replace('%', hostname) for part in command_template])
        for hostname in hostnames if hostname
    ]


def _print_captured_output(hostname, output):
    """Print the captured output bytes of one command, prefixed by its hostname."""
    try:
        print("%s: %s" % (hostname, output.decode('UTF-8')))
    except UnicodeError:
        # Either the output is not valid UTF-8 or stdout cannot encode the
        # decoded text; fall back to showing the raw bytes literal.
        print("%s: %s" % (hostname, output))


def main():
    """Parse the command line and run the command template once per remote.

    The remotes are read from the file ``SET_PATH % <set name>``. With
    ``--threads`` of 1 or less the commands run one after another and inherit
    this process's stdout and stderr. Otherwise they run concurrently in a
    thread pool; their stdout and stderr are captured together and, when
    non-empty, printed to stdout prefixed by the hostname as each command
    completes. A summary naming the remotes whose command returned a non-zero
    exit code is written to stderr at the end.

    Exits with status 1 when no command template is given.
    """
    parser = argparse.ArgumentParser(description=DESCRIPTION, epilog=EXAMPLES,
                                     formatter_class=argparse.RawDescriptionHelpFormatter)

    parser.add_argument('--set', metavar='NAME', default='default', type=str,
                        help='name of the set of remotes to apply the command to '
                             '(default: \'default\')')
    parser.add_argument('--threads', default=4, type=int,
                        help='maximum no. of concurrently running instances of the command')
    parser.add_argument('command', metavar='CMD', type=str, nargs='...',
                        help='the template of the command to run. for each instance of '
                             'the command, %% will be replaced with the remote hostname')
    args = parser.parse_args()

    if not args.command:
        print('bzremote: need a command to run', file=sys.stderr)
        sys.exit(1)

    remotes = _load_remotes(SET_PATH % args.set, args.command)

    exit_codes = []
    if args.threads <= 1:
        for hostname, cmd in remotes:
            print("bzremote: running '%s'" % " ".join(cmd), file=sys.stderr)
            result = subprocess.run(cmd)
            print("bzremote: %s finished (exit code: %d)"
                  % (hostname, result.returncode), file=sys.stderr)
            exit_codes.append((hostname, result.returncode))
    else:
        with futures.ThreadPoolExecutor(max_workers=args.threads) as executor:
            future_to_hostname = {
                executor.submit(subprocess.run, cmd, stderr=subprocess.STDOUT,
                                stdout=subprocess.PIPE): hostname
                for hostname, cmd in remotes
            }
            for future in futures.as_completed(future_to_hostname):
                result = future.result()
                hostname = future_to_hostname[future]

                # The captured output is bytes, so test for emptiness by
                # truthiness; comparing against the str "" never matches.
                output = result.stdout.strip()
                if output:
                    _print_captured_output(hostname, output)
                exit_codes.append((hostname, result.returncode))

    nonzero = [hostname for hostname, code in exit_codes if code != 0]
    if not nonzero:
        print("bzremote: all commands on %d remotes exited successfuly"
              % len(remotes), file=sys.stderr)
    else:
        print("bzremote: %d out of %d commands finished with non-zero exit codes"
              % (len(nonzero), len(remotes)), file=sys.stderr)
        print("bzremote: these were for hostnames:", file=sys.stderr)
        for hostname in nonzero:
            print("\t" + hostname, file=sys.stderr)

# Run the tool only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

