#!/bin/python3

# Copyright 2025, 2026 eomanis
#
# This file is part of coturn-babysitter.
#
# coturn-babysitter is free software: you can redistribute it and/or
# modify it under the terms of the GNU General Public License version 3
# as published by the Free Software Foundation.
#
# coturn-babysitter is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with coturn-babysitter.  If not, see
# <http://www.gnu.org/licenses/>.

import argparse
import os
import shutil
import stat
import subprocess
import sys
from io import StringIO

import dns.resolver
from dns.rdatatype import RdataType

# Version as (major, minor, patch, suffix); the suffix is optional ("" = none)
VERSION = (0, 0, 3, "")
# TODO Find out where the __version__ dunder goes and use that instead
# Rendered as "MAJOR.MINOR.PATCH", plus "-SUFFIX" only if a suffix is set
VERSION_STR = "{}.{}.{}".format(*VERSION[:3]) + (
    f"-{VERSION[3]}" if VERSION[3] else ""
)

# Do not change these strings willy-nilly: configuration files and
# templates written for earlier versions rely on them, so changing them
# breaks compatibility with those files
COTURN_CONFIG_CUTOFF_MARKER = "#### COTURN-BABYSITTER CUTOFF MARKER ####"
PLACEHOLDER_PUBLIC_IPV4_ADDR = "_COTURN_BABYSITTER_PUBLIC_IPV4_ADDR_"
PLACEHOLDER_PUBLIC_IPV6_ADDR = "_COTURN_BABYSITTER_PUBLIC_IPV6_ADDR_"

# Texts shown in the command line help output
DESCRIPTION = "coturn configuration rewriter and service reloader"
EPILOG = "\n".join(
    (
        "Supported placeholders in the template file:",
        f" - {PLACEHOLDER_PUBLIC_IPV4_ADDR}",
        f" - {PLACEHOLDER_PUBLIC_IPV6_ADDR}",
    )
)

# Defaults for the corresponding command line options
TEMPLATE_DEFAULT = "/etc/coturn-babysitter/turnserver.conf.template"
COTURN_CONFIG_DEFAULT = "/etc/turnserver/turnserver.conf"
SERVICE_DEFAULT = "turnserver.service"

# Whether debug() prints anything; set by main() from --verbose
verbose: bool = False
# TODO Figure out proper logging, i. e. a clean separation between the
# CoturnBabysitter class and the messaging methods


class CoturnBabysitter:
    """Regenerates the part of a coturn configuration file below the
    cutoff marker from a template, filling in the public IPv4 and IPv6
    addresses that a domain name currently resolves to, and optionally
    reloads the coturn systemd service if the configuration changed.

    Everything above COTURN_CONFIG_CUTOFF_MARKER in the configuration file
    is kept as it is.
    """

    def __init__(
        self,
        domain_name: str,
        template_file: str,
        coturn_config_file: str,
        coturn_config_file_tmp: str,
        coturn_config_file_bak: str,
        coturn_service: str,
        reload: bool,
        dry_run: bool,
    ) -> None:
        """Stores the settings for a run.

        domain_name: Domain name to resolve to the public addresses
        template_file: Template for the part below the cutoff marker
        coturn_config_file: coturn configuration file to rewrite
        coturn_config_file_tmp: If non-empty, the new configuration is first
            written to this file, which then replaces coturn_config_file via
            os.replace(); if empty, coturn_config_file is rewritten in-place
        coturn_config_file_bak: If non-empty, the previous configuration is
            copied to this file before it is replaced
        coturn_service: systemd service to reload after a change
        reload: Whether to reload coturn_service after a change
        dry_run: If True, only report what would be done
        """
        self.domain_name: str = domain_name
        self.template_file: str = template_file
        self.coturn_config_file: str = coturn_config_file
        self.coturn_config_file_tmp: str = coturn_config_file_tmp
        self.coturn_config_file_bak: str = coturn_config_file_bak
        self.coturn_service: str = coturn_service
        self.reload: bool = reload
        self.dry_run: bool = dry_run

    def run(self) -> None:
        """Generates a new coturn configuration file in memory based on
        the existing configuration file and the template.

        If the new configuration differs from the one in the existing file,
        replaces the existing file with the new configuration and, if reload is
        True, reloads the coturn systemd service."""

        # Print some info about what is going on
        info(f"This is coturn-babysitter {VERSION_STR}")
        if self.dry_run:
            info("This is a dry run (no configuration replacing, no service reloading)")
        debug("Verbose output is enabled")
        debug(f"Domain name to resolve is '{self.domain_name}'")
        debug(f"coturn-babysitter template file is '{self.template_file}'")
        debug(f"coturn configuration file to edit is '{self.coturn_config_file}'")
        if self.coturn_config_file_tmp:
            debug(
                f"Temporary coturn configuration file is '{self.coturn_config_file_tmp}'"
            )
        else:
            debug("coturn configuration file will be rewritten in-place")
        if self.coturn_config_file_bak:
            debug(
                f"Backup coturn configuration file is '{self.coturn_config_file_bak}'"
            )
        else:
            debug("coturn configuration file will not be backed up")
        if self.reload:
            debug(
                f"systemd service will be reloaded if required: '{self.coturn_service}'"
            )
        else:
            debug("coturn systemd service will not be reloaded")

        # Determine the public IPv4 and IPv6 addresses for the domain name
        # ("" means that no address of that family is known)
        ipv4_addr = get_ip_addr(self.domain_name, RdataType.A)
        if ipv4_addr != "":
            debug(f"Public IPv4 address is {ipv4_addr}")
        else:
            debug("Could not resolve the domain name to an IPv4 address")

        ipv6_addr = get_ip_addr(self.domain_name, RdataType.AAAA)
        if ipv6_addr != "":
            debug(f"Public IPv6 address is {ipv6_addr}")
        else:
            debug("Could not resolve the domain name to an IPv6 address")

        # Read the current coturn configuration file into a string
        with open(self.coturn_config_file) as file:
            coturn_config = file.read()

        # Generate the new coturn configuration, also in a string
        coturn_config_new = self.get_new_coturn_config(
            coturn_config=coturn_config,
            ipv4_addr=ipv4_addr,
            ipv6_addr=ipv6_addr,
        )

        # Compare the new to the old configuration, and if it is different,
        # replace the old configuration with the new one and reload coturn
        if coturn_config_new != coturn_config:
            self.update_coturn_config(coturn_config_new)

            if self.reload:
                self.reload_coturn_service()

        else:
            debug("coturn configuration unchanged, doing nothing")

    def get_new_coturn_config(
        self,
        coturn_config: str,
        ipv4_addr: str,
        ipv6_addr: str,
    ) -> str:
        """Returns a new coturn configuration that is composed from the
        supplied coturn configuration, the contents of the template file
        and the given IPv4 and IPv6 addresses.

        The lines of coturn_config before the cutoff marker line (or all of
        them, if there is no marker) are kept. They are padded so that at
        least two blank lines precede the marker. The marker and a notice
        line follow, then the template lines with the placeholders
        replaced. A template line containing a placeholder for which the
        corresponding address is "" is skipped, with a warning.
        """

        target = StringIO()
        # Copy everything from the existing coturn configuration until right
        # before the cutoff marker text line (if present)
        trailing_empty_lines = 0
        for line in coturn_config.splitlines():
            if line == COTURN_CONFIG_CUTOFF_MARKER:
                break
            if len(line.strip()) == 0:
                trailing_empty_lines += 1
            else:
                trailing_empty_lines = 0
            _ = target.write(line + "\n")

        # Pad with newlines until at least 2 blank lines precede the marker;
        # counting existing blank lines keeps repeated runs from adding more
        while trailing_empty_lines < 2:
            trailing_empty_lines += 1
            _ = target.write("\n")

        # Write the cutoff marker
        _ = target.write(COTURN_CONFIG_CUTOFF_MARKER + "\n\n")
        _ = target.write(
            "# Everything below this cutoff marker will be replaced by coturn-babysitter\n\n"
        )

        # Write the generated part of the coturn configuration
        with open(self.template_file) as template:
            for line in template:
                add_line = True
                line_derived = line

                if PLACEHOLDER_PUBLIC_IPV4_ADDR in line_derived:
                    if ipv4_addr != "":
                        line_derived = line_derived.replace(
                            PLACEHOLDER_PUBLIC_IPV4_ADDR, ipv4_addr
                        )
                    else:
                        add_line = False
                        warn(
                            "IPv4 address placeholder found in template text "
                            + "line, but no IPv4 address known: Skipping line"
                        )

                if PLACEHOLDER_PUBLIC_IPV6_ADDR in line_derived:
                    if ipv6_addr != "":
                        line_derived = line_derived.replace(
                            PLACEHOLDER_PUBLIC_IPV6_ADDR, ipv6_addr
                        )
                    else:
                        add_line = False
                        warn(
                            "IPv6 address placeholder found in template text "
                            + "line, but no IPv6 address known: Skipping line"
                        )

                if add_line:
                    _ = target.write(line_derived)

        return target.getvalue()

    def update_coturn_config(self, coturn_config_new: str) -> None:
        """Replaces the content of the coturn configuration file with
        coturn_config_new. Backs up the previous file first if a backup file
        name is set. In dry-run mode, only reports what would be done.

        With a temporary file name set, the new configuration is written
        there, given the original file's owner, group and permissions, and
        then moved over the original with os.replace(). The temporary file
        is removed again if any of these steps fails. Without a temporary
        file name, the original file is overwritten in-place.
        """
        if self.dry_run:
            info(
                "Dry run, skipping: "
                + "coturn configuration has changed, replacing configuration file"
            )
            return

        info("coturn configuration has changed, replacing configuration file")
        if self.coturn_config_file_tmp:
            # Write the new config to a temporary file that starts out
            # accessible only to its owner. A coturn configuration may
            # contain credentials, so it must not be readable by others
            # before it gets the original file's owner and permissions
            self._write_private_file(self.coturn_config_file_tmp, coturn_config_new)
            try:
                copy_user_group_perms(
                    from_file=self.coturn_config_file,
                    to_file=self.coturn_config_file_tmp,
                )
                # Back up the original coturn configuration file
                self.create_coturn_config_backup()
                # Switch out the original coturn configuration file with the
                # new one
                os.replace(self.coturn_config_file_tmp, self.coturn_config_file)
            except BaseException:
                # Do not leave a stray copy of the configuration behind;
                # cleanup is best-effort, the original error is re-raised
                try:
                    os.remove(self.coturn_config_file_tmp)
                except OSError:
                    pass
                raise
        else:
            # Back up the original coturn configuration file
            self.create_coturn_config_backup()
            # Overwrite the configuration file in-place
            with open(self.coturn_config_file, mode="w") as file:
                _ = file.write(coturn_config_new)

    @staticmethod
    def _write_private_file(path: str, content: str) -> None:
        """Writes content to a newly created file at path. The file is
        created with mode 0600 (further restricted by the umask) and
        replaces any existing file there. The data is fsync'ed before
        returning."""
        # Remove a leftover file (e.g. from an aborted earlier run) so that
        # the new file is created fresh instead of keeping that file's mode
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        # O_EXCL: fail rather than write into something that appeared at
        # path after the removal above
        fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        with os.fdopen(fd, mode="w") as file:
            _ = file.write(content)
            # Make sure the data is on disk before the file is renamed over
            # the live configuration, so a crash cannot leave a truncated one
            file.flush()
            os.fsync(file.fileno())

    def create_coturn_config_backup(self) -> None:
        """Copies the current coturn configuration file to the backup file
        with shutil.copy2(), then gives the backup the original's owner,
        group and permissions. Does nothing if no backup file name is
        set."""
        if not self.coturn_config_file_bak:
            return
        # Make sure a missing backup file is created owner-only before the
        # configuration is copied into it. Otherwise it would briefly carry
        # default (umask-based) permissions. An existing file is left as is
        os.close(os.open(self.coturn_config_file_bak, os.O_WRONLY | os.O_CREAT, 0o600))
        _ = shutil.copy2(self.coturn_config_file, self.coturn_config_file_bak)
        copy_user_group_perms(
            from_file=self.coturn_config_file, to_file=self.coturn_config_file_bak
        )

    def reload_coturn_service(self) -> None:
        """Runs `systemctl try-reload-or-restart` on the coturn service,
        raising subprocess.CalledProcessError if systemctl fails. In dry-run
        mode, only reports what would be done."""
        if self.dry_run:
            info(f"Dry run, skipping: Reloading '{self.coturn_service}'")
            return

        info(f"Reloading '{self.coturn_service}'")
        _ = subprocess.run(
            ["systemctl", "try-reload-or-restart", "--", self.coturn_service],
            check=True,
            text=True,
        )


def copy_user_group_perms(from_file: str, to_file: str) -> None:
    """Gives to_file the same owning user, owning group and permission bits
    as from_file.

    shutil.chown() is only called if the owner or group actually differ,
    and os.chmod() only if the permission bits actually differ.
    """
    from_stat = os.stat(from_file)
    to_stat = os.stat(to_file)

    if to_stat.st_uid != from_stat.st_uid or to_stat.st_gid != from_stat.st_gid:
        shutil.chown(
            path=to_file,
            user=from_stat.st_uid,
            group=from_stat.st_gid,
        )
        # chown(2) may clear the set-user-ID / set-group-ID bits, so the
        # mode read before the ownership change may be stale now
        to_stat = os.stat(to_file)

    # Compare only the permission bits; S_IMODE strips the file type bits
    from_mode = stat.S_IMODE(from_stat.st_mode)
    if stat.S_IMODE(to_stat.st_mode) != from_mode:
        os.chmod(path=to_file, mode=from_mode)


def get_ip_addr(qname: str, rdtype: RdataType) -> str:
    """Returns an address from the DNS records of type rdtype (RdataType.A
    or RdataType.AAAA) for qname as a string. Returns "" if qname has no
    records of that type.

    If there are several records, the lowest one in string order is
    returned. This makes the result independent of the order in which the
    DNS server lists the records, which may rotate between queries.

    Other resolver errors (e.g. a non-existent name or a timeout) are
    raised to the caller.
    """
    # We have to use dnspython instead of socket.getaddrinfo() because
    # we need an actual DNS lookup to get the public IP addresses
    # socket.getaddrinfo() also reads the HOSTS file, which won't do
    # because sometimes local services' domain names are mapped to
    # 0.0.0.0 and ::1 in the HOSTS file, and that is what
    # socket.getaddrinfo() returns then, which are very much *not*
    # public IP addresses
    # raise_on_no_answer=False: a name with, say, an A but no AAAA record
    # is a normal case that callers handle via "", not a reason to abort
    answer: dns.resolver.Answer = dns.resolver.resolve(
        qname=qname, rdtype=rdtype, raise_on_no_answer=False
    )
    # Deterministic pick so that a rotating multi-record answer does not
    # change the generated configuration (and trigger reloads) on every run
    return min((str(rdata) for rdata in answer), default="")  # pyright: ignore[reportAny]


def debug(message: str) -> None:
    """Prints message to stderr with a "DEBUG " prefix, but only while the
    module-level verbose flag is set."""
    if not verbose:
        return
    print(f"DEBUG {message}", file=sys.stderr)


def info(message: str) -> None:
    """Prints message to stderr with a " INFO " prefix."""
    print(f" INFO {message}", file=sys.stderr)


def warn(message: str) -> None:
    """Prints message to stderr with a " WARN " prefix."""
    print(f" WARN {message}", file=sys.stderr)


def error(message: str) -> None:
    """Prints message to stderr with an "ERROR " prefix."""
    print(f"ERROR {message}", file=sys.stderr)


def get_arg_parser() -> argparse.ArgumentParser:
    """Returns the command line parser for coturn-babysitter.

    --version is handled by argparse's "version" action. It prints
    VERSION_STR and exits with code 0 as soon as it is encountered, even
    when combined with other options.
    """
    result = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=DESCRIPTION,
        epilog=EPILOG,
    )

    # A plain store_true flag here would be ignored by main(), so e.g.
    # "-n example.com --version" would run a full rewrite instead of
    # printing the version; the "version" action exits immediately
    _ = result.add_argument(
        "--version",
        help="Print the version string and exit with code 0",
        action="version",
        version=VERSION_STR,
    )
    _ = result.add_argument(
        "-n",
        "--domain-name",
        help="The domain name to resolve to obtain the IPv4 and IPv6 addresses",
        type=str,
        required=True,
    )
    _ = result.add_argument(
        "-t",
        "--template",
        help="The coturn-babysitter template file (default: '%(default)s')",
        default=TEMPLATE_DEFAULT,
        type=str,
    )
    _ = result.add_argument(
        "-c",
        "--coturn-config",
        help="The coturn configuration file to edit (default: '%(default)s')",
        default=COTURN_CONFIG_DEFAULT,
        type=str,
    )
    _ = result.add_argument(
        "-T",
        "--use-temp-file",
        help="Do not overwrite the coturn configuration file in-place, use a "
        + "temporary file instead (at --coturn-config + .cb-tmp)",
        action="store_true",
    )
    _ = result.add_argument(
        "-r",
        "--reload",
        help="Reload the coturn systemd service if the configuration was changed",
        action="store_true",
    )
    _ = result.add_argument(
        "-s",
        "--service",
        help="The coturn systemd service to reload on configuration change "
        + "(default: '%(default)s')",
        default=SERVICE_DEFAULT,
        type=str,
    )
    _ = result.add_argument(
        "-d",
        "--dry-run",
        help="Simulation run (no configuration replacing, no service reloading)",
        action="store_true",
    )
    _ = result.add_argument(
        "-v",
        "--verbose",
        help="Print DEBUG messages",
        action="store_true",
    )
    return result


def main(args: argparse.Namespace) -> int:
    """Builds a CoturnBabysitter from the parsed command line options in
    args and runs it. Also sets the module-level verbose flag.

    Returns the process exit code: 0 on success, or 1 if any exception was
    raised, in which case its message is reported via error()."""
    global verbose
    try:
        verbose = bool(args.verbose)
        config_path = str(args.coturn_config)
        # An empty temporary file name means "rewrite in-place"
        tmp_path = config_path + ".cb-tmp" if bool(args.use_temp_file) else ""

        CoturnBabysitter(
            domain_name=str(args.domain_name),
            template_file=str(args.template),
            coturn_config_file=config_path,
            coturn_config_file_tmp=tmp_path,
            coturn_config_file_bak=config_path + ".cb-bak",
            coturn_service=str(args.service),
            reload=bool(args.reload),
            dry_run=bool(args.dry_run),
        ).run()
    except Exception as exception:
        error(f"{exception}")
        return 1
    return 0


if __name__ == "__main__":
    arg_parser = get_arg_parser()
    cli_args = sys.argv[1:]

    # Without any arguments, show the help text and fail
    if not cli_args:
        arg_parser.print_help(sys.__stdout__)
        sys.exit(1)

    # A lone --version prints the version and succeeds, without requiring
    # the otherwise mandatory options
    if cli_args == ["--version"]:
        print(VERSION_STR)
        sys.exit(0)

    sys.exit(main(arg_parser.parse_args()))