diff --git a/.github/actions/set-up-legacy-python/action.yml b/.github/actions/set-up-legacy-python/action.yml index 791c7db..54a209b 100644 --- a/.github/actions/set-up-legacy-python/action.yml +++ b/.github/actions/set-up-legacy-python/action.yml @@ -22,6 +22,6 @@ runs: command: pip install -U pip - name: Install dependencies run: | - python -m pip install flake8 pytest setuptools wheel + python -m pip install flake8 pytest setuptools wheel pydantic pyyaml if [ -f requirements.txt ]; then pip install -r requirements.txt; fi shell: bash diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a00243e..4dd41ae 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -29,7 +29,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pyright pyflakes pytest setuptools wheel + python -m pip install flake8 pyright pyflakes pytest setuptools wheel pydantic pyyaml if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - name: Lint with flake8 run: | diff --git a/cf_remote/commands.py b/cf_remote/commands.py index 53bc9a4..f96bfb0 100644 --- a/cf_remote/commands.py +++ b/cf_remote/commands.py @@ -3,6 +3,7 @@ import subprocess import sys import time +import yaml from multiprocessing.dummy import Pool from cf_remote.remote import ( @@ -50,6 +51,13 @@ from cf_remote.spawn import spawn_vms, destroy_vms, dump_vms_info, get_cloud_driver from cf_remote import log from cf_remote import cloud_data +from cf_remote.up import validate_config +from cf_remote.validate import ( + validate_aws_credentials, + validate_gcp_credentials, + validate_aws_image, + validate_vagrant_box, +) def info(hosts, users=None): @@ -437,10 +445,6 @@ def spawn( vagrant_sync_folder=None, vagrant_provision=None, ): - creds_data = None - if os.path.exists(CLOUD_CONFIG_FPATH): - creds_data = read_json(CLOUD_CONFIG_FPATH) - vms_info = None if os.path.exists(CLOUD_STATE_FPATH): vms_info = read_json(CLOUD_STATE_FPATH) @@ -457,39 +461,12 @@ def spawn( sec_groups = None key_pair = None if provider == Providers.AWS: - if not creds_data: - raise CFRUserError( - "Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH - ) - try: - creds = _get_aws_creds_from_env() or AWSCredentials( - creds_data["aws"]["key"], - creds_data["aws"]["secret"], - creds_data["aws"].get("token", ""), - ) - sec_groups = creds_data["aws"]["security_groups"] - key_pair = creds_data["aws"]["key_pair"] - except KeyError: - print("Incomplete AWS credential info") # TODO: report missing keys - return 1 - - region = region or creds_data["aws"].get("region", "eu-west-1") + creds, region, sec_groups, key_pair = validate_aws_credentials() + validate_aws_image(platform) elif provider == Providers.GCP: - if not creds_data: - raise CFRUserError( - "Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH - ) - try: - creds = GCPCredentials( - creds_data["gcp"]["project_id"], - creds_data["gcp"]["service_account_id"], - creds_data["gcp"]["key_path"], - ) - except KeyError: - print("Incomplete GCP credential info") # TODO: report missing keys - return 1 - - region = region or creds_data["gcp"].get("region", "europe-west1-b") + creds, region = validate_gcp_credentials() + else: + validate_vagrant_box(platform) # TODO: Do we have to complicate this instead of just assuming existing VMs # were created by this code and thus follow the naming pattern from this @@ -1086,3 +1063,18 @@ def connect_cmd(hosts): print("") log.error("The ssh command exited with error code " + str(r.returncode)) return r.returncode + + +def up_cmd(config_path): + content = None + try: + with open(config_path, "r") as f: + content = yaml.safe_load(f) + except yaml.YAMLError: + raise CFRUserError("'%s' is not a valid YAML file" % config_path) + except FileNotFoundError: + raise CFRUserError("'%s' doesn't exist" % config_path) + + validate_config(content) + + return 0 diff --git a/cf_remote/main.py b/cf_remote/main.py index 2ff1026..ef3f8ad 100644 --- a/cf_remote/main.py +++ b/cf_remote/main.py @@ -318,6 +318,11 @@ def _get_arg_parser(): "--hosts", "-H", help="Host to open the shell on", type=str, required=True ) + sp = subp.add_parser("up", help="Run cf-remote from a yaml config") + sp.add_argument( + "config", help="Path to yaml config", default="config.yaml", nargs="?" + ) + return ap @@ -455,6 +460,8 @@ def run_command_with_args(command, args) -> int: return commands.agent(args.hosts, args.bootstrap) elif command == "connect": return commands.connect_cmd(args.hosts) + elif command == "up": + return commands.up_cmd(args.config) else: raise CFRExitError("Unknown command: '{}'".format(command)) diff --git a/cf_remote/spawn.py b/cf_remote/spawn.py index a7e3b25..cef39a5 100644 --- a/cf_remote/spawn.py +++ b/cf_remote/spawn.py @@ -22,7 +22,6 @@ from cf_remote.paths import cf_remote_dir, CLOUD_STATE_FPATH from cf_remote.utils import CFRUserError, whoami, copy_file, canonify, read_json from cf_remote import log -from cf_remote import cloud_data VAGRANT_VM_IP_START = "192.168.56.9" _NAME_RANDOM_PART_LENGTH = 4 @@ -414,12 +413,6 @@ def spawn_vm_in_aws( size=None, role=None, ): - platform_name = platform.split("-")[0] - if platform_name not in aws_image_criteria: - raise ValueError( - "Platform '%s' is not in our set of image criteria. (Available platforms: %s)" - % (platform, ", ".join(cloud_data.aws_image_criteria.keys())) - ) try: driver = get_cloud_driver(Providers.AWS, aws_creds, region) existing_vms = driver.list_nodes() @@ -774,6 +767,7 @@ def spawn_vm_in_vagrant( sync_folder=None, provision_script=None, ): + name = canonify(name).replace("_", "-") vagrantdir = cf_remote_dir(os.path.join("vagrant", name)) os.makedirs(vagrantdir, exist_ok=True) diff --git a/cf_remote/up.py b/cf_remote/up.py new file mode 100644 index 0000000..95a95ef --- /dev/null +++ b/cf_remote/up.py @@ -0,0 +1,226 @@ +from pydantic import BaseModel, model_validator, ValidationError, Field +from typing import Union, Literal, Optional, List, Annotated +from functools import reduce + +from cf_remote.utils import CFRUserError +from cf_remote import log + +import cf_remote.validate as validate + + +# Forces pydantic to throw validation error if config contains unknown keys +class NoExtra(BaseModel, extra="forbid"): + pass + + +class Config(NoExtra): + pass + + +class AWSConfig(Config): + image: str + size: Literal["micro", "xlarge"] = "micro" + + @model_validator(mode="after") + def check_aws_config(self): + validate.validate_aws_image(self.image) + return self + + +class VagrantConfig(Config): + box: str + memory: int = 512 + cpus: int = 1 + sync_folder: Optional[str] = None + provision: Optional[str] = None + + @model_validator(mode="after") + def check_vagrant_config(self): + if self.memory < 512: + raise CFRUserError("Cannot allocate less than 512MB to a Vagrant VM") + if self.cpus < 1: + raise CFRUserError("Cannot use less than 1 cpu per Vagrant VM") + + validate.validate_vagrant_box(self.box) + + return self + + +class GCPConfig(Config): + image: str # There is no list of avalaible GCP platforms to validate against yet + network: Optional[str] = None + public_ip: bool = True + size: str = "n1-standard-1" + + +class AWSProvider(Config): + provider: Literal["aws"] + aws: AWSConfig + + @model_validator(mode="after") + def check_aws_provider(self): + validate.validate_aws_credentials() + return self + + +class GCPProvider(Config): + provider: Literal["gcp"] + gcp: GCPConfig + + @model_validator(mode="after") + def check_gcp_provider(self): + validate.validate_gcp_credentials() + return self + + +class VagrantProvider(Config): + provider: Literal["vagrant"] + vagrant: VagrantConfig + + +class SaveMode(Config): + mode: Literal["save"] + hosts: List[str] + + +class SpawnMode(Config): + mode: Literal["spawn"] + # "Field" forces pydantic to report errors on the branch defined by the field "provider" + spawn: Annotated[ + Union[VagrantProvider, AWSProvider, GCPProvider], + Field(discriminator="provider"), + ] + count: int + + @model_validator(mode="after") + def check_spawn_config(self): + if self.count < 1: + raise CFRUserError("Cannot spawn less than 1 instance") + return self + + +class CFEngineConfig(Config): + version: Optional[str] = None + bootstrap: Optional[str] = None + edition: Literal["community", "enterprise"] = "enterprise" + remote_download: bool = False + hub_package: Optional[str] = None + client_package: Optional[str] = None + package: Optional[str] = None + demo: bool = False + + @model_validator(mode="after") + def check_cfengine_config(self): + packages = [self.package, self.hub_package, self.client_package] + for p in packages: + validate.validate_package(p, self.remote_download) + + if self.version and any(packages): + log.warning("Specifying package overrides cfengine version") + + validate.validate_version(self.version, self.edition) + validate.validate_state_bootstrap(self.bootstrap) + + return self + + +class GroupConfig(Config): + role: Literal["client", "hub"] + # "Field" forces pydantic to report errors on the branch defined by the field "provider" + source: Annotated[Union[SaveMode, SpawnMode], Field(discriminator="mode")] + cfengine: Optional[CFEngineConfig] = None + scripts: Optional[List[str]] = None + + @model_validator(mode="after") + def check_group_config(self): + if ( + self.role == "hub" + and self.source.mode == "spawn" + and self.source.count != 1 + ): + raise CFRUserError("A hub can only have one host") + + return self + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return reduce(_getattr, [obj] + attr.split(".")) + + +class Group: + """ + All group-specific data: + - Vagrantfile + Config that declares it: + - provider, count, cfengine version, role, ... + """ + + def __init__(self, config: GroupConfig): + self.config = config + self.hosts = [] + + +class Host: + """ + All host-specific data: + - user, ip, ssh config, OS, uuid, ... + """ + + def __init__(self): + pass + + +def _resolve_templates(parent, templates): + if not parent: + return + if isinstance(parent, dict): + for key, value in parent.items(): + if isinstance(value, str) and value in templates: + parent[key] = templates[value] + else: + _resolve_templates(value, templates) + if isinstance(parent, list): + for value in parent: + _resolve_templates(value, templates) + + +def validate_config(content): + if not content: + raise CFRUserError("Empty spawn config") + + if "groups" not in content: + raise CFRUserError("Missing 'groups' key in spawn config") + + groups = content["groups"] + templates = content.get("templates") + if templates: + _resolve_templates(groups, templates) + + if not isinstance(groups, list): + groups = [groups] + + state = {} + try: + for g in groups: + if len(g) != 1: + raise CFRUserError( + "Too many keys in group definition: {}".format( + ", ".join(list(g.keys())) + ) + ) + + for k, v in g.items(): + state[k] = Group(GroupConfig(**v)) + + except ValidationError as v: + msgs = [] + for err in v.errors(): + msgs.append( + "{}. Input '{}' at location '{}'".format( + err["msg"], err["input"], err["loc"] + ) + ) + raise CFRUserError("\n".join(msgs)) diff --git a/cf_remote/validate.py b/cf_remote/validate.py new file mode 100644 index 0000000..647f85c --- /dev/null +++ b/cf_remote/validate.py @@ -0,0 +1,133 @@ +from cf_remote.paths import CLOUD_CONFIG_FPATH, CLOUD_STATE_FPATH +from cf_remote.utils import read_json, CFRUserError, is_package_url +from cf_remote.spawn import ( + Providers, + get_cloud_driver, + InvalidCredsError, + AWSCredentials, + GCPCredentials, +) +from cf_remote.cloud_data import aws_image_criteria +from cf_remote.remote import Releases + +import os +import subprocess + + +def validate_state_bootstrap(bootstrap): + state = read_json(CLOUD_STATE_FPATH) + if state is None: + return + key = "@{}".format(bootstrap) + + # TODO: Change how to check this if cloud_state.json changes format + if key in state and state[key].values()[1]["role"] != "hub": + raise CFRUserError("Cannot bootstrap to an existing host that is not a hub") + + +def validate_package(package, remote_download=False): + if package is None: + return + + if remote_download and not is_package_url(package): + raise CFRUserError("Package '{}' is not a valid package URL") + + +def validate_version(version, edition): + releases = Releases(edition) + release = releases.default + if version: + release = releases.pick_version(version) + if release is None: + raise CFRUserError( + "Could not find a release for version {}. The supported versions are {}".format( + version, releases + ) + ) + + return release + + +def validate_vagrant_box(box): + ret = subprocess.run(["vagrant", "box", "list"], capture_output=True, text=True) + box_list = [ + line.split()[0] for line in ret.stdout.split("\n") if len(line.split()) > 0 + ] + + if box not in box_list: + raise CFRUserError("Box '{}' is not installed or doesn't exist".format(box)) + + +def validate_aws_image(platform): + platform_name = platform.split("-")[0] + if platform_name not in aws_image_criteria: + raise CFRUserError( + "Platform '%s' is not in our set of image criteria. (Available platforms: %s)" + % (platform, ", ".join(aws_image_criteria.keys())) + ) + + +def _get_aws_creds_from_env(): + if "AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ: + return AWSCredentials( + os.environ["AWS_ACCESS_KEY_ID"], + os.environ["AWS_SECRET_ACCESS_KEY"], + os.environ.get("AWS_SESSION_TOKEN", ""), + ) + return None + + +def validate_aws_credentials(): + creds_data = None + if os.path.exists(CLOUD_CONFIG_FPATH): + creds_data = read_json(CLOUD_CONFIG_FPATH) + + if not creds_data: + raise CFRUserError("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH) + creds = None + try: + creds = _get_aws_creds_from_env() or AWSCredentials( + creds_data["aws"]["key"], + creds_data["aws"]["secret"], + creds_data["aws"].get("token", ""), + ) + except KeyError: + raise CFRUserError( + "Incomplete AWS credential info" + ) # TODO: report missing keys + + region = creds_data["aws"].get("region", "eu-west-1") + sec_groups = creds_data["aws"]["security_groups"] + key_pair = creds_data["aws"]["key_pair"] + + if creds: + try: + get_cloud_driver(Providers.AWS, creds, region) + except InvalidCredsError as error: + raise CFRUserError( + "Invalid credentials, check cloud_config.json (%s.)" % str(error)[1:-1] + ) + return creds, region, sec_groups, key_pair + + +def validate_gcp_credentials(): + creds_data = None + if os.path.exists(CLOUD_CONFIG_FPATH): + creds_data = read_json(CLOUD_CONFIG_FPATH) + + if not creds_data: + raise CFRUserError("Cloud configuration not found at %s" % CLOUD_CONFIG_FPATH) + try: + creds = GCPCredentials( + creds_data["gcp"]["project_id"], + creds_data["gcp"]["service_account_id"], + creds_data["gcp"]["key_path"], + ) + except KeyError: + raise CFRUserError( + "Incomplete AWS credential info" + ) # TODO: report missing keys + + region = creds_data["gcp"].get("region", "europe-west1-b") + + return creds, region