#!/usr/bin/env python3
"""
vcf_make_ref.py
Create a VCF-consistent reference FASTA by writing VCF REF alleles into a filler reference.

SNP-only v1:
- skips indels and multi-base alleles
- does NOT normalize indels
- supports .vcf and .vcf.gz
- supports optional --bound start,stop (inclusive; VCF POS is 1-based)
- supports optional --chrom chr1,chr2,...
- supports optional --strict (replace any non-ACGTN base in the output with N)

Typical use for simulation:
  1) Take any reference/filler FASTA with correct contig length (can be real genome or all-N).
  2) Run this script to overwrite bases at SNP positions with VCF.REF.
  3) Feed the resulting FASTA into vcf2mig.py --ref ... to build finite-mutation sequences.

This avoids relying on external FASTA being consistent with simulated VCF allele states.
"""

import sys
import gzip
import argparse


def open_text_maybe_gzip(path):
    """
    Open *path* for reading as ASCII text.

    Paths ending in ".gz" are decompressed on the fly with gzip; any other
    path is opened as a plain file. Bytes that are not valid ASCII are
    decoded as U+FFFD rather than raising.
    """
    opener = gzip.open if path.endswith(".gz") else open
    return opener(path, "rt", encoding="ascii", errors="replace")


def fasta_read_all(path):
    """
    Read a (possibly gzipped) multi-FASTA file into memory.

    Returns a tuple (order, headers, seqs):
      order:   list of contig IDs in file order
      headers: dict contig_id -> full header line (surrounding whitespace stripped)
      seqs:    dict contig_id -> list of single-character strings (mutable)

    contig_id is the first whitespace-delimited token after '>'.
    Blank lines are ignored.

    Raises:
      ValueError: if the file contains no records, a header has no ID,
                  sequence data appears before the first header, or a
                  contig ID occurs more than once.
    """
    order = []
    headers = {}
    seqs = {}

    def _flush(cid, header, chunks):
        # Store one finished record. A repeated ID would otherwise be listed
        # twice in `order` while only the last sequence is kept, so the
        # written FASTA would silently duplicate that contig.
        if cid in seqs:
            raise ValueError(f"Duplicate FASTA contig ID '{cid}' in {path}")
        order.append(cid)
        headers[cid] = header
        seqs[cid] = list("".join(chunks))

    with open_text_maybe_gzip(path) as f:
        cur_id = None
        cur_header = None
        chunks = []

        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if cur_id is not None:
                    _flush(cur_id, cur_header, chunks)
                tokens = line[1:].split()
                if not tokens:
                    raise ValueError(f"Empty FASTA header at line {lineno} of {path}")
                cur_header = line
                cur_id = tokens[0]
                chunks = []
            elif cur_id is None:
                # Previously these lines were silently discarded.
                raise ValueError(
                    f"Sequence data before the first FASTA header at line {lineno} of {path}"
                )
            else:
                chunks.append(line)

        if cur_id is not None:
            _flush(cur_id, cur_header, chunks)

    if not order:
        raise ValueError(f"No FASTA records found in {path}")

    return order, headers, seqs


def fasta_write_all(path, order, headers, seqs, wrap=60):
    """
    Write contigs to *path* as ASCII FASTA, in the sequence given by *order*.

    Each record is its header line (trailing whitespace removed) followed
    by the sequence. When *wrap* is a positive int the sequence is split
    into lines of at most *wrap* characters. Otherwise (0, None, negative)
    the sequence is written on a single line.
    """
    with open(path, "wt", encoding="ascii") as handle:
        for contig in order:
            handle.write(headers[contig].rstrip() + "\n")
            sequence = "".join(seqs[contig])

            if not wrap or wrap <= 0:
                handle.write(sequence + "\n")
                continue

            for offset in range(0, len(sequence), wrap):
                handle.write(sequence[offset:offset + wrap] + "\n")


def norm_chrom(x):
    """
    Return *x* without a leading 'chr' prefix (matched case-insensitively).

    This is intentionally conservative: only that prefix is removed, so
    'chr1' -> '1' and 'CHRX' -> 'X', while other names are unchanged.
    """
    return x[3:] if x.lower().startswith("chr") else x


def parse_bound(s):
    """
    Parse a 'start,stop' string into an ordered (start, stop) tuple of ints.

    The two values are swapped if given in descending order. Only the first
    comma splits the string. Malformed input raises ValueError: a missing
    comma, or a part that int() rejects.
    """
    first, second = s.split(",", 1)
    low, high = sorted((int(first), int(second)))
    return low, high


def vcf_iter_records(vcf_path):
    """
    Yield (chrom, pos, ref, alts) for each data line of a VCF (.vcf or .vcf.gz).

      chrom: CHROM column, unchanged
      pos:   int, VCF POS (1-based)
      ref:   REF column, unchanged
      alts:  ALT column split on ',' ([] if the column is empty;
             a '.' ALT is kept as ['.'])

    Lines starting with '#' (meta/header) and lines with fewer than five
    tab-separated columns are skipped.

    Raises:
      ValueError: if a data line has a POS that is not an integer.
    """
    with open_text_maybe_gzip(vcf_path) as f:
        for lineno, line in enumerate(f, start=1):
            if not line or line[0] == "#":
                continue
            # Strip '\r' too: in CRLF files it would otherwise stay attached to
            # the last column (ALT in a 5-column record). An ALT like "A\r" has
            # length 2, so a SNP would be misclassified as an indel.
            parts = line.rstrip("\r\n").split("\t")
            if len(parts) < 5:
                continue
            chrom = parts[0]
            try:
                pos = int(parts[1])
            except ValueError:
                raise ValueError(
                    f"{vcf_path}:{lineno}: invalid VCF POS {parts[1]!r}"
                ) from None
            ref = parts[3]
            alt = parts[4]
            alts = alt.split(",") if alt else []
            yield chrom, pos, ref, alts


def choose_contig_mapping(vcf_chroms, fasta_contigs):
    """
    Build a dict mapping each resolvable VCF CHROM to a FASTA contig ID.

    Each VCF CHROM is resolved in this order:
      1) the identical FASTA contig ID, if present;
      2) otherwise a FASTA contig whose ID matches after norm_chrom()
         (leading 'chr' removed, case-insensitively). If several FASTA
         contigs normalize to the same key, the last one in fasta_contigs wins;
      3) finally, if there is exactly one VCF CHROM and exactly one FASTA
         contig and that CHROM is still unmapped, pair them and print a
         warning to stderr.

    CHROM values that cannot be resolved are left out of the result.
    """
    exact_ids = set(fasta_contigs)
    normalized_ids = {norm_chrom(contig): contig for contig in fasta_contigs}

    mapping = {}
    for chrom in vcf_chroms:
        if chrom in exact_ids:
            mapping[chrom] = chrom
        else:
            key = norm_chrom(chrom)
            if key in normalized_ids:
                mapping[chrom] = normalized_ids[key]

    # Last resort: pair a single-contig FASTA with a single-CHROM VCF.
    if len(fasta_contigs) != 1 or len(vcf_chroms) != 1:
        return mapping

    (lone_chrom,) = vcf_chroms
    if lone_chrom in mapping:
        return mapping

    lone_contig = fasta_contigs[0]
    print(
        f"Warning: VCF chrom '{lone_chrom}' does not match FASTA contig '{lone_contig}', "
        f"but each file has exactly one contig — mapping them anyway.",
        file=sys.stderr
    )
    mapping[lone_chrom] = lone_contig
    return mapping


def _apply_snp_refs(vcf_path, chrom_allow, bound, mapping, seqs):
    """
    Overwrite bases in *seqs* (in place) with single-base VCF REF alleles.

    Each record is filtered in this order:
      1) *chrom_allow* (None keeps every CHROM);
      2) SNP-only: REF and every non-'.' ALT must be 1 bp;
      3) *bound*: inclusive (start, stop) on VCF POS, or None;
      4) the position must lie inside the mapped contig.
    REF bases that are not A/C/G/T/N after upper-casing are written as 'N'.

    Returns a dict of counters: total, used, skipped_chrom, skipped_indel,
    skipped_bound, skipped_oob, conflicts.
    """
    stats = dict.fromkeys(
        ("total", "used", "skipped_chrom", "skipped_indel",
         "skipped_bound", "skipped_oob", "conflicts"),
        0,
    )
    # 0-based positions already written, per FASTA contig. Used to detect
    # VCF records that disagree on the REF base at the same site; the later
    # record still wins, but the disagreement is counted.
    written = {}

    for chrom, pos, ref, alts in vcf_iter_records(vcf_path):
        if chrom_allow is not None and chrom not in chrom_allow:
            stats["skipped_chrom"] += 1
            continue

        stats["total"] += 1

        # SNP-only v1: require single-base REF and all ALTs single-base
        # (multi-allelic allowed if all 1bp; '.' means no ALT).
        if len(ref) != 1 or any(len(a) != 1 for a in alts if a != "."):
            stats["skipped_indel"] += 1
            continue

        if bound is not None and not (bound[0] <= pos <= bound[1]):
            stats["skipped_bound"] += 1
            continue

        fasta_id = mapping[chrom]
        seq = seqs[fasta_id]
        idx = pos - 1  # VCF POS is 1-based; list index is 0-based

        if idx < 0 or idx >= len(seq):
            stats["skipped_oob"] += 1
            continue

        base = ref.upper()
        if base not in "ACGTN":
            base = "N"

        seen = written.setdefault(fasta_id, set())
        if idx in seen and seq[idx] != base:
            stats["conflicts"] += 1
        seen.add(idx)

        seq[idx] = base
        stats["used"] += 1

    return stats


def _normalize_case(order, seqs, strict):
    """
    Upper-case every contig sequence in *seqs*, replacing each list.

    With *strict*, any character that is not A/C/G/T/N after upper-casing
    is replaced by 'N'.
    """
    for cid in order:
        upper = [ch.upper() for ch in seqs[cid]]
        if strict:
            upper = [c if c in "ACGTN" else "N" for c in upper]
        seqs[cid] = upper


def main(argv):
    """
    Command-line entry point.

    Steps:
      1) Read --fasta.
      2) Write the REF allele of every selected single-base VCF record into it.
      3) Upper-case the result; with --strict, also force non-ACGTN to 'N'.
      4) Write --out and print a summary to stderr.

    Exits with status 2 on an invalid --bound (via argparse), or when a
    selected VCF CHROM cannot be mapped to any FASTA contig. Raises
    ValueError if no VCF records remain after --chrom filtering.
    """
    ap = argparse.ArgumentParser(
        prog="vcf_make_ref.py",
        description="Write VCF REF alleles into a filler reference FASTA (SNP-only v1)."
    )
    ap.add_argument("--vcf", required=True, help="Input VCF (.vcf or .vcf.gz)")
    ap.add_argument("--fasta", required=True, help="Input FASTA filler (.fa/.fasta; gz ok)")
    ap.add_argument("--out", required=True, help="Output FASTA path")
    ap.add_argument("--bound", default=None, help="start,stop (inclusive; VCF POS coordinates)")
    ap.add_argument("--chrom", default=None, help="comma-separated list of VCF CHROM values to include")
    ap.add_argument("--strict", action="store_true",
                    help="replace any non-ACGTN in output with 'N' (after edits)")
    args = ap.parse_args(argv)

    bound = None
    if args.bound:
        try:
            bound = parse_bound(args.bound)
        except ValueError:
            # Report malformed input as a usage error rather than a traceback.
            ap.error(f"--bound must be 'start,stop' with integer values, got {args.bound!r}")

    chrom_allow = None
    if args.chrom:
        # Tolerate "chr1, chr2" and stray commas; otherwise " chr2" never matches.
        chrom_allow = {c.strip() for c in args.chrom.split(",") if c.strip()}

    order, headers, seqs = fasta_read_all(args.fasta)
    fasta_contigs = order[:]  # contig IDs

    # First pass: which CHROM values occur in the VCF (after --chrom)?
    vcf_chroms = {
        chrom
        for chrom, _pos, _ref, _alts in vcf_iter_records(args.vcf)
        if chrom_allow is None or chrom in chrom_allow
    }

    if not vcf_chroms:
        raise ValueError("No VCF records found after applying --chrom filter (if any).")

    if chrom_allow is not None:
        for c in sorted(chrom_allow - vcf_chroms):
            print(f"Warning: --chrom value '{c}' not found in VCF; ignoring it.", file=sys.stderr)

    mapping = choose_contig_mapping(vcf_chroms, fasta_contigs)
    missing = [c for c in sorted(vcf_chroms) if c not in mapping]
    if missing:
        print("Error: could not map these VCF CHROM values to any FASTA contig:", file=sys.stderr)
        for c in missing:
            print(f"  {c}", file=sys.stderr)
        print("Tip: ensure FASTA contig IDs match VCF CHROM (or use a single-contig FASTA).", file=sys.stderr)
        sys.exit(2)

    # Second pass: apply SNP REF bases, then normalize case.
    stats = _apply_snp_refs(args.vcf, chrom_allow, bound, mapping, seqs)
    _normalize_case(order, seqs, args.strict)

    fasta_write_all(args.out, order, headers, seqs, wrap=60)

    print(f"VCF records processed: {stats['total']}", file=sys.stderr)
    print(f"SNP REF bases written: {stats['used']}", file=sys.stderr)
    if stats["skipped_chrom"]:
        print(f"Skipped by --chrom:   {stats['skipped_chrom']}", file=sys.stderr)
    if stats["skipped_bound"]:
        print(f"Skipped by --bound:   {stats['skipped_bound']}", file=sys.stderr)
    if stats["skipped_indel"]:
        print(f"Skipped indels/len>1: {stats['skipped_indel']}", file=sys.stderr)
    if stats["skipped_oob"]:
        print(f"Skipped out-of-range: {stats['skipped_oob']}", file=sys.stderr)
    if stats["conflicts"]:
        print(f"Conflicting REF:      {stats['conflicts']} (later record overwrote earlier)",
              file=sys.stderr)
    print(f"Wrote output FASTA:   {args.out}", file=sys.stderr)


if __name__ == "__main__":
    # Script entry: hand the command-line arguments (without the program name) to main().
    main(sys.argv[1:])
