"""Generate a Go package from thrift-files/meowment-network.

The script locates a working thrift compiler, runs it once per ``.thrift`` file
into a temporary staging directory, copies the generated package into the
output directory and optionally writes ``go.mod``, runs ``go mod tidy`` and
runs ``gofmt``.
"""

from __future__ import annotations

import argparse
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path
from typing import Iterable, Sequence

# The script file is assumed to sit at REPO_ROOT/<dir>/<dir>/<script>.py.
REPO_ROOT = Path(__file__).resolve().parents[2]
DEFAULT_SOURCE_DIR = REPO_ROOT / "thrift-files" / "meowment-network"
DEFAULT_OUTPUT_DIR = REPO_ROOT / "compiled_output" / "golang" / "meowmentnet"
DEFAULT_PACKAGE_NAME = "meowmentnet"
DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift"
# Versions written into go.mod by write_go_mod().
DEFAULT_GO_VERSION = "1.22"
DEFAULT_THRIFT_MODULE_VERSION = "v0.22.0"


class ToolError(RuntimeError):
    """Expected, user-facing failure; main() prints it and exits with status 1."""


def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line options.

    Args:
        argv: Arguments to parse; ``None`` (the default) means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a Go package from meowment-network thrift files.",
    )
    parser.add_argument(
        "--source-dir",
        type=Path,
        default=DEFAULT_SOURCE_DIR,
        help="Directory containing .thrift files.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Destination directory of the generated Go package.",
    )
    parser.add_argument(
        "--package-name",
        default=DEFAULT_PACKAGE_NAME,
        help="Generated Go package name.",
    )
    parser.add_argument(
        "--compiler",
        type=Path,
        help="Explicit thrift compiler executable path.",
    )
    parser.add_argument(
        "--package-prefix",
        default="",
        help="Optional value passed to go:package_prefix=...",
    )
    parser.add_argument(
        "--thrift-import",
        default=DEFAULT_THRIFT_IMPORT,
        help="Import path of the Go thrift runtime.",
    )
    parser.add_argument(
        "--module",
        default="",
        help="If set, write a go.mod file with this module path.",
    )
    parser.add_argument(
        "--no-clean",
        action="store_true",
        help="Do not delete the existing output directory before copying files.",
    )
    parser.add_argument(
        "--skip-gofmt",
        action="store_true",
        help="Do not run gofmt after generation.",
    )
    parser.add_argument(
        "--skip-go-mod-tidy",
        action="store_true",
        help="Do not run go mod tidy after writing go.mod.",
    )
    return parser.parse_args(argv)


def candidate_compilers(explicit: Path | None) -> Iterable[Path]:
    """Yield thrift compiler candidates in priority order, without duplicates.

    Order: *explicit*, the ``THRIFT_COMPILER`` environment variable, the
    repository-bundled ``compiler/exe/thrift-go.exe``, then ``thrift`` found on
    PATH. Sources are consulted lazily, only as the caller keeps iterating.
    """

    def raw_candidates() -> Iterable[Path | None]:
        yield explicit
        # An empty THRIFT_COMPILER is treated as unset; Path("") would mean ".".
        env_value = os.environ.get("THRIFT_COMPILER")
        yield Path(env_value) if env_value else None
        yield REPO_ROOT / "compiler" / "exe" / "thrift-go.exe"
        system_compiler = shutil.which("thrift")
        yield Path(system_compiler) if system_compiler else None

    seen: set[str] = set()
    for candidate in raw_candidates():
        if candidate is None:
            continue
        key = str(candidate)
        if key in seen:
            continue
        seen.add(key)
        yield candidate


def probe_compiler(executable: Path) -> bool:
    """Return True if ``<executable> -version`` exits with status 0 within 10 seconds."""
    try:
        completed = subprocess.run(
            [str(executable), "-version"],
            check=False,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        # Missing, non-executable or hanging binaries are simply not usable.
        return False
    return completed.returncode == 0


def resolve_compiler(explicit: Path | None) -> Path:
    """Return the first candidate from candidate_compilers() that passes probe_compiler().

    Raises:
        ToolError: If no candidate works; the message lists every path tried.
    """
    tried: list[Path] = []
    for compiler in candidate_compilers(explicit):
        if probe_compiler(compiler):
            return compiler
        tried.append(compiler)
        if explicit is not None and compiler == explicit:
            # Falling back silently would hide a mistyped or broken --compiler.
            print(f"[warn] --compiler {compiler} did not run successfully. Trying other compilers.")
    tried_text = "\n".join(f"  - {path}" for path in tried)
    raise ToolError(
        "No working thrift compiler was found. Use --compiler or set THRIFT_COMPILER.\n"
        f"Tried:\n{tried_text}"
    )


def find_thrift_files(source_dir: Path) -> list[Path]:
    """Return the ``*.thrift`` files directly inside *source_dir*, sorted by path.

    Raises:
        ToolError: If *source_dir* is missing, is not a directory, or contains
            no thrift files.
    """
    if not source_dir.exists():
        raise ToolError(f"Source directory does not exist: {source_dir}")
    if not source_dir.is_dir():
        raise ToolError(f"Source path is not a directory: {source_dir}")
    thrift_files = sorted(source_dir.glob("*.thrift"))
    if not thrift_files:
        raise ToolError(f"No .thrift files were found in: {source_dir}")
    return thrift_files


def build_generator_option(args: argparse.Namespace) -> str:
    """Build the ``--gen`` value for the thrift Go generator.

    Reads ``args.package_name``, ``args.package_prefix`` and ``args.thrift_import``.
    ``ignore_initialisms`` and ``skip_remote`` (no ``*-remote`` client programs)
    are always set.
    """
    options = [f"package={args.package_name}", "ignore_initialisms"]
    # Normalize before testing so a whitespace-only prefix does not emit "package_prefix=".
    package_prefix = normalize_package_prefix(args.package_prefix)
    if package_prefix:
        options.append(f"package_prefix={package_prefix}")
    if args.thrift_import:
        options.append(f"thrift_import={args.thrift_import}")
    options.append("skip_remote")
    return "go:" + ",".join(options)


def normalize_package_prefix(prefix: str) -> str:
    """Strip *prefix* and make a non-empty result end with ``/``; blank input gives ``""``."""
    normalized = prefix.strip()
    if normalized and not normalized.endswith("/"):
        normalized += "/"
    return normalized


def run_command(command: Sequence[str], cwd: Path) -> None:
    """Run *command* (without a shell) in *cwd*, capturing its output.

    Output of a successful run is discarded.

    Raises:
        ToolError: On a non-zero exit status. The message is the command's
            stripped stdout and stderr, or a generic note if both are empty.
    """
    completed = subprocess.run(
        list(command),
        cwd=str(cwd),
        check=False,
        capture_output=True,
        text=True,
    )
    if completed.returncode == 0:
        return
    stdout = completed.stdout.strip()
    stderr = completed.stderr.strip()
    details = "\n".join(part for part in (stdout, stderr) if part)
    if not details:
        details = "The command exited with a non-zero status and produced no output."
    raise ToolError(details)


def generate_go_sources(
    compiler: Path,
    source_dir: Path,
    thrift_files: Sequence[Path],
    package_name: str,
    generator_option: str,
    staging_root: Path,
) -> Path:
    """Run *compiler* on each thrift file and return the staged package directory.

    Each file is compiled separately with
    ``-o staging_root -strict --gen generator_option``, using *source_dir* as
    the working directory.

    Returns:
        ``staging_root / "gen-go" / package_name``.

    Raises:
        ToolError: If a compiler run fails (the message names the file), or
            the expected package directory was not produced.
    """
    for thrift_file in thrift_files:
        command = [
            str(compiler),
            "-o",
            str(staging_root),
            "-strict",
            "--gen",
            generator_option,
            str(thrift_file),
        ]
        try:
            run_command(command, cwd=source_dir)
        except ToolError as exc:
            raise ToolError(f"Failed to generate {thrift_file.name}:\n{exc}") from exc

    package_dir = staging_root / "gen-go" / package_name
    if not package_dir.exists():
        raise ToolError(f"Generated package directory not found: {package_dir}")
    return package_dir


def write_go_mod(
    output_dir: Path,
    module_name: str,
    *,
    go_version: str = DEFAULT_GO_VERSION,
    thrift_version: str = DEFAULT_THRIFT_MODULE_VERSION,
) -> None:
    """Write (overwriting) ``output_dir/go.mod`` for *module_name*.

    The file declares the Go version and a requirement on
    ``github.com/apache/thrift`` at *thrift_version*.
    """
    # NOTE(review): the require line ignores --thrift-import; a non-apache runtime
    # path presumably relies on `go mod tidy` to fix up requirements — confirm.
    lines = [
        f"module {module_name}",
        "",
        f"go {go_version}",
        "",
        f"require github.com/apache/thrift {thrift_version}",
        "",
    ]
    (output_dir / "go.mod").write_text("\n".join(lines), encoding="utf-8")


def copy_output(staged_package_dir: Path, output_dir: Path, clean: bool) -> None:
    """Copy the staged package tree to *output_dir*.

    With *clean*, an existing *output_dir* is deleted first (``shutil.rmtree``).
    Otherwise files are copied over it, and existing files that are not
    overwritten are left in place.
    """
    if clean and output_dir.exists():
        shutil.rmtree(output_dir)
    output_dir.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(staged_package_dir, output_dir, dirs_exist_ok=not clean)


def run_gofmt(output_dir: Path) -> None:
    """Format the top-level ``*.go`` files of *output_dir* in place with gofmt.

    Prints a warning and returns if gofmt is not on PATH.

    Raises:
        ToolError: If gofmt exits with a non-zero status.
    """
    gofmt = shutil.which("gofmt")
    if not gofmt:
        print("[warn] gofmt was not found in PATH. Skipping formatting.")
        return
    go_files = sorted(output_dir.glob("*.go"))
    if not go_files:
        return
    # One invocation handles every file; a process per file is needlessly slow.
    try:
        run_command([gofmt, "-w", *(str(path) for path in go_files)], cwd=output_dir)
    except ToolError as exc:
        raise ToolError(f"gofmt failed:\n{exc}") from exc


def run_go_mod_tidy(output_dir: Path) -> None:
    """Run ``go mod tidy`` in *output_dir*; warn and return if ``go`` is not on PATH.

    Raises:
        ToolError: If ``go mod tidy`` exits with a non-zero status.
    """
    go = shutil.which("go")
    if not go:
        print("[warn] go was not found in PATH. Skipping go mod tidy.")
        return
    try:
        run_command([go, "mod", "tidy"], cwd=output_dir)
    except ToolError as exc:
        raise ToolError(f"go mod tidy failed:\n{exc}") from exc


def validate_args(args: argparse.Namespace) -> None:
    """Reject option values that cannot produce a usable package.

    Raises:
        ToolError: If ``--package-name`` is blank, or ``--module`` is given but blank.
    """
    if not args.package_name.strip():
        raise ToolError("--package-name cannot be empty.")
    if args.module and not args.module.strip():
        raise ToolError("--module cannot be blank.")


def check_output_dir_is_safe(output_dir: Path, source_dir: Path) -> None:
    """Refuse to clean an output directory that would take inputs down with it.

    When cleaning, copy_output() deletes *output_dir* recursively. If that
    directory is, or is an ancestor of, the source directory or REPO_ROOT, the
    thrift sources or the whole repository would be wiped. Both paths are
    expected to be resolved.

    Raises:
        ToolError: If *output_dir* equals or contains *source_dir* or REPO_ROOT.
    """
    for protected in (source_dir, REPO_ROOT):
        if output_dir == protected or output_dir in protected.parents:
            raise ToolError(
                f"Refusing to clean output directory {output_dir}: it is or contains {protected}. "
                "Choose a different --output-dir or pass --no-clean."
            )


def main(argv: Sequence[str] | None = None) -> int:
    """Run the generator end to end and return the exit status (0 success, 1 error).

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)
    try:
        validate_args(args)
        source_dir = args.source_dir.resolve()
        output_dir = args.output_dir.resolve()
        thrift_files = find_thrift_files(source_dir)
        clean = not args.no_clean
        if clean:
            # Fail before any work if the later rmtree would destroy inputs.
            check_output_dir_is_safe(output_dir, source_dir)
        compiler = resolve_compiler(args.compiler.resolve() if args.compiler else None)
        generator_option = build_generator_option(args)

        with tempfile.TemporaryDirectory(prefix="thrift_go_") as tmp_dir:
            staged_package_dir = generate_go_sources(
                compiler=compiler,
                source_dir=source_dir,
                thrift_files=thrift_files,
                package_name=args.package_name,
                generator_option=generator_option,
                staging_root=Path(tmp_dir),
            )
            copy_output(staged_package_dir, output_dir, clean=clean)

        if args.module:
            write_go_mod(output_dir, args.module)
            if not args.skip_go_mod_tidy:
                run_go_mod_tidy(output_dir)

        if not args.skip_gofmt:
            run_gofmt(output_dir)
    except (ToolError, OSError) as exc:
        # OSError covers filesystem failures (rmtree, copytree, go.mod, temp cleanup).
        print(f"[error] {exc}", file=sys.stderr)
        return 1

    print(f"Compiler: {compiler}")
    print(f"Source directory: {source_dir}")
    print(f"Output directory: {output_dir}")
    print(f"Package name: {args.package_name}")
    print(f"Thrift files: {len(thrift_files)}")
    if args.module:
        print(f"go.mod: {args.module}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())