303 lines
8.8 KiB
Python
303 lines
8.8 KiB
Python
"""Generate a Go package from thrift-files/meowment-network."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
from pathlib import Path
|
|
from typing import Iterable, Sequence
|
|
|
|
|
|
# Repository root: the directory two levels above the one containing this script.
REPO_ROOT = Path(__file__).resolve().parents[2]

# Defaults used by the command-line options when the caller does not override them.
DEFAULT_PACKAGE_NAME = "meowmentnet"
DEFAULT_THRIFT_IMPORT = "github.com/apache/thrift/lib/go/thrift"
DEFAULT_SOURCE_DIR = REPO_ROOT.joinpath("thrift-files", "meowment-network")
DEFAULT_OUTPUT_DIR = REPO_ROOT.joinpath("compiled_output", "golang", "meowmentnet")
|
|
|
|
|
|
class ToolError(RuntimeError):
    """Expected failure of the generator, reported to the user as a message."""
|
|
|
|
|
|
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the generator's command-line options.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default) argparse falls back to ``sys.argv[1:]``,
            so existing zero-argument callers behave exactly as before.

    Returns:
        Namespace with ``source_dir``, ``output_dir``, ``package_name``,
        ``compiler`` (``None`` unless given), ``package_prefix``,
        ``thrift_import``, ``module``, ``no_clean``, ``skip_gofmt`` and
        ``skip_go_mod_tidy``.
    """
    parser = argparse.ArgumentParser(
        description="Generate a Go package from meowment-network thrift files.",
    )
    parser.add_argument(
        "--source-dir",
        type=Path,
        default=DEFAULT_SOURCE_DIR,
        help="Directory containing .thrift files.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=DEFAULT_OUTPUT_DIR,
        help="Destination directory of the generated Go package.",
    )
    parser.add_argument(
        "--package-name",
        default=DEFAULT_PACKAGE_NAME,
        help="Generated Go package name.",
    )
    parser.add_argument(
        "--compiler",
        type=Path,
        help="Explicit thrift compiler executable path.",
    )
    parser.add_argument(
        "--package-prefix",
        default="",
        help="Optional value passed to go:package_prefix=...",
    )
    parser.add_argument(
        "--thrift-import",
        default=DEFAULT_THRIFT_IMPORT,
        help="Import path of the Go thrift runtime.",
    )
    parser.add_argument(
        "--module",
        default="",
        help="If set, write a go.mod file with this module path.",
    )
    parser.add_argument(
        "--no-clean",
        action="store_true",
        help="Do not delete the existing output directory before copying files.",
    )
    parser.add_argument(
        "--skip-gofmt",
        action="store_true",
        help="Do not run gofmt after generation.",
    )
    parser.add_argument(
        "--skip-go-mod-tidy",
        action="store_true",
        help="Do not run go mod tidy after writing go.mod.",
    )
    # Passing None keeps argparse's default of reading sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
|
|
def candidate_compilers(explicit: Path | None) -> Iterable[Path]:
|
|
seen: set[str] = set()
|
|
|
|
def emit(path: Path | None) -> Iterable[Path]:
|
|
if path is None:
|
|
return ()
|
|
resolved = str(path)
|
|
if resolved in seen:
|
|
return ()
|
|
seen.add(resolved)
|
|
return (path,)
|
|
|
|
if explicit is not None:
|
|
yield from emit(explicit)
|
|
|
|
env_compiler = Path(os.environ["THRIFT_COMPILER"]) if "THRIFT_COMPILER" in os.environ else None
|
|
yield from emit(env_compiler)
|
|
yield from emit(REPO_ROOT / "compiler" / "exe" / "thrift-go.exe")
|
|
|
|
system_compiler = shutil.which("thrift")
|
|
if system_compiler:
|
|
yield from emit(Path(system_compiler))
|
|
|
|
|
|
def probe_compiler(executable: Path) -> bool:
    """Report whether *executable* can be run as ``<executable> -version``.

    Args:
        executable: Path of the compiler to try.

    Returns:
        ``True`` only when the process starts, finishes within 10 seconds and
        exits with status 0; ``False`` on ``OSError``, on timeout, or on a
        non-zero exit status.
    """
    argv = [str(executable), "-version"]
    try:
        result = subprocess.run(
            argv,
            check=False,
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0
|
|
|
|
|
|
def resolve_compiler(explicit: Path | None) -> Path:
    """Return the first candidate compiler that passes the version probe.

    Args:
        explicit: Compiler path requested by the caller, or ``None``.

    Returns:
        The first path from ``candidate_compilers`` accepted by
        ``probe_compiler``.

    Raises:
        ToolError: If no candidate passes the probe.
    """
    working = next(
        (
            candidate
            for candidate in candidate_compilers(explicit)
            if probe_compiler(candidate)
        ),
        None,
    )
    if working is None:
        raise ToolError(
            "No working thrift compiler was found. Use --compiler or set THRIFT_COMPILER."
        )
    return working
|
|
|
|
|
|
def find_thrift_files(source_dir: Path) -> list[Path]:
    """Return the ``.thrift`` files directly inside *source_dir*, sorted.

    Only regular files (or links to them) are returned; a sub-directory whose
    name happens to end in ``.thrift`` is skipped so it is never handed to the
    compiler as an input file. The search is not recursive.

    Args:
        source_dir: Directory to scan.

    Returns:
        Sorted list of matching file paths; never empty.

    Raises:
        ToolError: If *source_dir* does not exist, is not a directory, or
            contains no ``.thrift`` files.
    """
    if not source_dir.exists():
        raise ToolError(f"Source directory does not exist: {source_dir}")
    if not source_dir.is_dir():
        # Without this check a file path would be reported as "no .thrift files".
        raise ToolError(f"Source path is not a directory: {source_dir}")

    thrift_files = sorted(
        entry for entry in source_dir.glob("*.thrift") if entry.is_file()
    )
    if not thrift_files:
        raise ToolError(f"No .thrift files were found in: {source_dir}")
    return thrift_files
|
|
|
|
|
|
def build_generator_option(args: argparse.Namespace) -> str:
    """Build the ``go:...`` generator option string from parsed arguments.

    ``package=<name>`` and ``ignore_initialisms`` always come first and
    ``skip_remote`` always comes last; ``package_prefix`` and
    ``thrift_import`` are included only when the matching argument is
    non-empty.

    Args:
        args: Namespace providing ``package_name``, ``package_prefix`` and
            ``thrift_import``.

    Returns:
        ``"go:"`` followed by the comma-separated options.
    """
    parts = [f"package={args.package_name}", "ignore_initialisms"]
    if args.package_prefix:
        prefix = normalize_package_prefix(args.package_prefix)
        parts.append(f"package_prefix={prefix}")
    if args.thrift_import:
        parts.append(f"thrift_import={args.thrift_import}")
    parts.append("skip_remote")
    return "go:" + ",".join(parts)
|
|
|
|
|
|
def normalize_package_prefix(prefix: str) -> str:
    """Strip surrounding whitespace and make a non-empty prefix end in ``/``.

    Args:
        prefix: Raw prefix text.

    Returns:
        The empty string when *prefix* is blank; otherwise the stripped prefix
        with exactly the trailing slash it already had, or one appended.
    """
    trimmed = prefix.strip()
    if trimmed and not trimmed.endswith("/"):
        return trimmed + "/"
    return trimmed
|
|
|
|
|
|
def run_command(command: Sequence[str], cwd: Path) -> None:
    """Run *command* in *cwd*, raising ``ToolError`` unless it succeeds.

    Output is captured rather than streamed; it is only surfaced, as part of
    the error message, when the command fails.

    Args:
        command: Program and arguments to execute (no shell is involved).
        cwd: Working directory for the child process.

    Raises:
        ToolError: If the process cannot be started (for example the
            executable or *cwd* is missing) or exits with a non-zero status.
            The message names the command and status, followed by the
            captured stdout/stderr.
    """
    argv = list(command)
    rendered = " ".join(argv)
    try:
        completed = subprocess.run(
            argv,
            cwd=str(cwd),
            check=False,
            capture_output=True,
            text=True,
        )
    except OSError as exc:
        # Launch failures would otherwise bypass the ToolError-based reporting.
        raise ToolError(f"Could not run command: {rendered}\n{exc}") from exc

    if completed.returncode == 0:
        return

    stdout = completed.stdout.strip()
    stderr = completed.stderr.strip()
    details = "\n".join(part for part in (stdout, stderr) if part)
    if not details:
        details = "The command exited with a non-zero status and produced no output."
    raise ToolError(
        f"Command exited with status {completed.returncode}: {rendered}\n{details}"
    )
|
|
|
|
|
|
def generate_go_sources(
    compiler: Path,
    source_dir: Path,
    thrift_files: Sequence[Path],
    package_name: str,
    generator_option: str,
    staging_root: Path,
) -> Path:
    """Compile each thrift file into *staging_root* and locate the Go package.

    The compiler is invoked once per file, with *source_dir* as the working
    directory, as ``<compiler> -o <staging_root> -strict --gen <option> <file>``.

    Args:
        compiler: Thrift compiler executable.
        source_dir: Working directory for every compiler invocation.
        thrift_files: Input files, compiled in the given order.
        package_name: Name of the package directory expected under ``gen-go``.
        generator_option: Value passed after ``--gen``.
        staging_root: Directory passed to the compiler's ``-o`` option.

    Returns:
        ``staging_root / "gen-go" / package_name``.

    Raises:
        ToolError: If a compiler run fails (the message names the file) or the
            expected package directory does not exist afterwards.
    """
    base_command = [
        str(compiler),
        "-o",
        str(staging_root),
        "-strict",
        "--gen",
        generator_option,
    ]
    for idl in thrift_files:
        try:
            run_command([*base_command, str(idl)], cwd=source_dir)
        except ToolError as exc:
            raise ToolError(f"Failed to generate {idl.name}:\n{exc}") from exc

    generated = staging_root / "gen-go" / package_name
    if generated.exists():
        return generated
    raise ToolError(f"Generated package directory not found: {generated}")
|
|
|
|
|
|
def write_go_mod(output_dir: Path, module_name: str) -> None:
    """Write ``go.mod`` into *output_dir*, replacing any existing one.

    The file declares *module_name*, Go ``1.22`` and a requirement on
    ``github.com/apache/thrift v0.22.0``, and is encoded as UTF-8.

    Args:
        output_dir: Existing directory that receives the file.
        module_name: Module path for the ``module`` directive.
    """
    lines = (
        f"module {module_name}",
        "",
        "go 1.22",
        "",
        "require github.com/apache/thrift v0.22.0",
    )
    # Terminate every line, including the last, with a newline.
    content = "".join(line + "\n" for line in lines)
    (output_dir / "go.mod").write_text(content, encoding="utf-8")
|
|
|
|
|
|
def copy_output(staged_package_dir: Path, output_dir: Path, clean: bool) -> None:
    """Copy the staged package tree to *output_dir*.

    Args:
        staged_package_dir: Directory tree to copy from.
        output_dir: Destination directory; missing parents are created.
        clean: When true, an existing *output_dir* is deleted first so the
            result contains only the staged files. When false, the copy is
            merged into whatever is already there.
    """
    if clean:
        if output_dir.exists():
            shutil.rmtree(output_dir)
        merge_into_existing = False
    else:
        merge_into_existing = True

    output_dir.parent.mkdir(parents=True, exist_ok=True)
    shutil.copytree(
        staged_package_dir,
        output_dir,
        dirs_exist_ok=merge_into_existing,
    )
|
|
|
|
|
|
def run_gofmt(output_dir: Path) -> None:
    """Format every ``*.go`` file directly inside *output_dir* with gofmt.

    Files are processed one at a time, in sorted order, via ``gofmt -w``.
    If ``gofmt`` is not on ``PATH`` a warning is printed and nothing is
    formatted.

    Args:
        output_dir: Directory holding the generated Go files; also used as the
            working directory for gofmt.

    Raises:
        ToolError: If a gofmt invocation fails.
    """
    formatter = shutil.which("gofmt")
    if not formatter:
        print("[warn] gofmt was not found in PATH. Skipping formatting.")
        return

    # An empty match list simply makes the loop a no-op.
    for source in sorted(output_dir.glob("*.go")):
        run_command([formatter, "-w", str(source)], cwd=output_dir)
|
|
|
|
|
|
def run_go_mod_tidy(output_dir: Path) -> None:
    """Run ``go mod tidy`` with *output_dir* as the working directory.

    If ``go`` is not on ``PATH`` a warning is printed and the step is skipped.

    Args:
        output_dir: Directory in which the command is executed.

    Raises:
        ToolError: If the command fails.
    """
    go_tool = shutil.which("go")
    if go_tool:
        run_command([go_tool, "mod", "tidy"], cwd=output_dir)
    else:
        print("[warn] go was not found in PATH. Skipping go mod tidy.")
|
|
|
|
|
|
def validate_args(args: argparse.Namespace) -> None:
    """Reject argument combinations the generator cannot work with.

    Args:
        args: Parsed command-line namespace; only ``package_name`` is checked.

    Raises:
        ToolError: If ``package_name`` is empty or whitespace-only.
    """
    if args.package_name.strip():
        return
    raise ToolError("--package-name cannot be empty.")
|
|
|
|
|
|
def main() -> int:
    """Run the generation pipeline and return a process exit status.

    Steps: validate arguments, find the thrift files, pick a compiler,
    generate into a temporary staging directory, copy the package to the
    output directory, optionally write ``go.mod`` / run ``go mod tidy`` /
    run gofmt, then print a summary.

    Returns:
        ``0`` on success; ``1`` when a step raises ``ToolError``, in which
        case the message is printed to stderr with an ``[error]`` prefix.
    """
    args = parse_args()
    try:
        validate_args(args)
        source_dir = args.source_dir.resolve()
        output_dir = args.output_dir.resolve()
        thrift_files = find_thrift_files(source_dir)
        explicit_compiler = args.compiler.resolve() if args.compiler else None
        compiler = resolve_compiler(explicit_compiler)
        generator_option = build_generator_option(args)

        with tempfile.TemporaryDirectory(prefix="thrift_go_") as staging:
            staged_package_dir = generate_go_sources(
                compiler=compiler,
                source_dir=source_dir,
                thrift_files=thrift_files,
                package_name=args.package_name,
                generator_option=generator_option,
                staging_root=Path(staging),
            )
            copy_output(staged_package_dir, output_dir, clean=not args.no_clean)
            # NOTE(review): the reviewed copy had lost its indentation; tidy is
            # assumed to run only when go.mod was written -- confirm.
            if args.module:
                write_go_mod(output_dir, args.module)
                if not args.skip_go_mod_tidy:
                    run_go_mod_tidy(output_dir)
            if not args.skip_gofmt:
                run_gofmt(output_dir)

        summary = [
            f"Compiler: {compiler}",
            f"Source directory: {source_dir}",
            f"Output directory: {output_dir}",
            f"Package name: {args.package_name}",
            f"Thrift files: {len(thrift_files)}",
        ]
        if args.module:
            summary.append(f"go.mod: {args.module}")
        for line in summary:
            print(line)
        return 0
    except ToolError as exc:
        print(f"[error] {exc}", file=sys.stderr)
        return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|