#! /usr/bin/env python3
"""p2p_ALOS2_SCAN_Frame — process a single ALOS-2 ScanSAR frame (5 subswaths)
through preprocess, samp_slc upsampling, per-subswath p2p, then merge.
Python port of `p2p_ALOS2_SCAN_Frame.csh` (Xiaohua Xu, 01/2019).

Usage: p2p_ALOS2_SCAN_Frame <master_stem> <aligned_stem> <config.txt> <parallel>
"""
import os, sys, glob, subprocess, multiprocessing
from gmtsar_lib import run


def _grep_field(path, key, col=3):
    if not os.path.isfile(path): return ""
    with open(path) as f:
        for line in f:
            if key in line:
                parts = line.split()
                if len(parts) >= col: return parts[col-1]
    return ""


def _rewrite_config(src, dst, overrides):
    """Replace `<key> = .*` lines; append missing keys. Mimics the sed chain in csh."""
    if not os.path.isfile(src):
        sys.exit(f"_rewrite_config: src config missing: {src} — staged config or "
                 f"bundled config.txt not found. Refusing to write empty {dst}.")
    with open(src) as f:
        lines = f.readlines()
    seen, out = set(), []
    for line in lines:
        stripped = line.lstrip()
        replaced = False
        for k, v in overrides.items():
            if stripped.startswith(k+" ") or stripped.startswith(k+"="):
                out.append(f"{k} = {v}\n"); seen.add(k); replaced = True; break
        if not replaced: out.append(line)
    for k, v in overrides.items():
        if k not in seen: out.append(f"{k} = {v}\n")
    with open(dst, "w") as f:
        f.writelines(out)


def _link_subswath_inputs(master, aligned, master_stem, aligned_stem, n):
    """Create ./raw and ./topo in the current subswath dir and fill them with symlinks.

    Expected to run from inside ``F{n}/``: every target is two levels up, in
    the case directory's own ``topo/`` and ``raw/``. Returns with the working
    directory unchanged.
    """
    for sub in ("raw", "topo"):
        os.makedirs(sub, exist_ok=True)

    os.chdir("topo")
    run("ln -sf ../../topo/dem.grd .")
    os.chdir("..")

    # One LED file per scene feeds every subswath; each subswath gets its own
    # "-F{n}"-suffixed alias pointing at the same source.
    raw_links = (
        f"ln -sf ../../raw/{master_stem}-F{n} .",
        f"ln -sf ../../raw/{aligned_stem}-F{n} .",
        f"ln -sf ../../raw/LED-{master} ./LED-{master}-F{n}",
        f"ln -sf ../../raw/LED-{aligned} ./LED-{aligned}-F{n}",
    )
    os.chdir("raw")
    for cmd in raw_links:
        run(cmd)
    os.chdir("..")


def _process_one_subswath(n, master, aligned, master_arg, aligned_arg, abs_conf, case_dir, stage):
    """Run the p2p_processing chain for subswath F{n} inside ``case_dir/F{n}``.

    When ``stage < 2`` the subswath is first preprocessed only (stages 2-6
    skipped), and both raw SLCs are resampled with ``samp_slc <stem> 3350 0``.
    In every case p2p_processing is then re-run with stage 1 skipped.

    The working directory is reset to ``case_dir`` on entry and restored on
    exit, including on failure. multiprocessing.Pool reuses worker processes,
    so a worker stranded in ``F{n}/`` or ``F{n}/raw/`` would create the next
    subswath's directory in the wrong place.

    Raises:
        RuntimeError: inside a Pool worker, in place of a SystemExit raised
            by a step (see the except clause).
    """
    sw = f"F{n}"
    os.chdir(case_dir)
    try:
        os.makedirs(sw, exist_ok=True)
        os.chdir(sw)
        _link_subswath_inputs(master, aligned, master_arg, aligned_arg, n)
        # Config must be named "config.py" — p2p_processing does `import config`
        # after `sys.path.insert(0, os.getcwd())`, so the file is loaded as a
        # Python module from cwd.
        if stage < 2:
            # First call: preprocess only — skip stages 2-6 via separate skip_N=1.
            # The csh script uses `skip_stage = 2,3,4,5,6` (comma list), but the
            # Python p2p_processing reads skip_1..skip_6 as separate module attrs.
            _rewrite_config(abs_conf, "config.py", {
                "skip_2": "1", "skip_3": "1", "skip_4": "1", "skip_5": "1", "skip_6": "1"})
            run(f"p2p_processing ALOS2_SCAN {master_arg}-F{n} {aligned_arg}-F{n} config.py")
            os.chdir("raw")
            run(f"samp_slc {master_arg}-F{n} 3350 0")
            run(f"samp_slc {aligned_arg}-F{n} 3350 0")
            os.chdir("..")
        # Second call: skip stage 1 (preprocess already done).
        _rewrite_config(abs_conf, "config.py", {"skip_1": "1"})
        run(f"p2p_processing ALOS2_SCAN {master_arg}-F{n} {aligned_arg}-F{n} config.py")
    except SystemExit as exc:
        # Pool workers only report Exception subclasses back to the parent.
        # A worker that dies on SystemExit (e.g. _rewrite_config's sys.exit)
        # loses its task, and starmap() then waits forever. Convert it only in
        # workers, so serial runs keep their original exit behaviour.
        if multiprocessing.current_process().name == "MainProcess":
            raise
        raise RuntimeError(f"{sw}: {exc.code}") from exc
    finally:
        os.chdir(case_dir)


def _collect_pth_prm(swath):
    prms = sorted(glob.glob(f"../F{swath}/intf/*/*.PRM"))
    if len(prms) < 2:
        raise RuntimeError(f"F{swath}/intf/*/*.PRM: need 2 PRM files, found {len(prms)}")
    pth = os.path.dirname(prms[0]) + "/"
    return pth, os.path.basename(prms[0]), os.path.basename(prms[1])


def _run_subswaths(master, aligned, master_arg, aligned_arg, abs_conf, case_dir, stage, parallel):
    """Process subswaths F1..F5, in a 5-process Pool when ``parallel == 1``."""
    args = [(n, master, aligned, master_arg, aligned_arg, abs_conf, case_dir, stage)
            for n in (1, 2, 3, 4, 5)]
    if parallel == 1:
        # Each subswath gets its own subprocess — chdir is process-local so
        # workers don't trip over each other. F1..F5 typically each peg one
        # core during p2p_processing, so 5-way is fine on multi-core boxes.
        with multiprocessing.Pool(processes=5) as pool:
            pool.starmap(_process_one_subswath, args)
    else:
        for a in args:
            _process_one_subswath(*a)


def _recompute_lut(stem, led_dir):
    """Rebuild trans.dat (the projection LUT later used by proj_ra2ll).

    Must run in merge/. The LUT is built from dem.grd and ``{stem}.PRM``.
    Exits if dem.grd, the PRM's led_file entry, or the LED file under
    ``led_dir`` is missing.
    """
    if os.path.exists("trans.dat"):
        os.remove("trans.dat")
    if not os.path.exists("dem.grd"):
        sys.exit("p2p_ALOS2_SCAN_Frame: missing dem.grd in merge/")
    led = _grep_field(f"{stem}.PRM", "led_file")
    if not led:
        sys.exit(f"p2p_ALOS2_SCAN_Frame: led_file missing in {stem}.PRM — "
                 f"trans.dat would be junk; refusing to proceed.")
    led_src = f"{led_dir}{led}"
    if not os.path.exists(led_src):
        sys.exit(f"p2p_ALOS2_SCAN_Frame: LED file not found at {led_src}")
    run(f"cp {led_src} .")
    print("Recomputing the projection LUT...")
    run(f"gmt grd2xyz --FORMAT_FLOAT_OUT=%lf dem.grd -s | SAT_llt2rat {stem}.PRM 1 -bod > trans.dat")


def _merge_subswaths(abs_conf, det_stitch):
    """Merge the F1..F5 interferograms into the current directory (merge/).

    This is done in two passes of merge_unwrap_geocode_tops: first F1+F2+F3
    (saved to tmp_first/), then tmp_first+F4+F5. Unwrapping and geocoding are
    disabled for both passes. Returns the basename of the merge config
    written here.
    """
    run("ln -sf ../topo/dem.grd .")
    for g in glob.glob("../F1/intf/*/gauss*"):
        run(f"ln -sf {g} .")
    merge_conf = os.path.basename(abs_conf)
    _rewrite_config(abs_conf, merge_conf, {
        "threshold_geocode": "0", "threshold_snaphu": "0", "iono_skip_est": "1"})
    pths = {n: _collect_pth_prm(n) for n in (1, 2, 3, 4, 5)}
    # First merge: F1+F2+F3
    with open("tmp.filelist", "w") as f:
        for n in (1, 2, 3):
            pth, m, s = pths[n]
            f.write(f"{pth}:{m}:{s}\n")
    # merge_unwrap_geocode_tops derives stem via rsplit('.', 1)[0] — for
    # ALOS2_SCAN PRM names like 'IMG-HH-...-WBDR1.1__D-F1.PRM' this gives
    # 'IMG-HH-...-WBDR1.1__D-F1'. csh uses awk -F"." '{print $1}' which would
    # split at the FIRST '.' (giving '...WBDR1') — but we must match what the
    # py merge tool actually writes, not the csh convention.
    stem = pths[1][1].rsplit(".", 1)[0]
    # An empty trans.dat is the placeholder touched here (presumably so the
    # partial merges don't each build a LUT — TODO confirm against
    # merge_unwrap_geocode_tops). A run that died before the LUT was rebuilt
    # leaves that placeholder behind, so treat an empty file as missing.
    recompute_lut = not os.path.exists("trans.dat") or os.path.getsize("trans.dat") == 0
    if recompute_lut:
        open("trans.dat", "w").close()
    run(f"merge_unwrap_geocode_tops tmp.filelist {merge_conf} {det_stitch}")
    # Post-condition: merge1 MUST have produced phasefilt/corr/mask + the stem
    # PRM. Failing fast here saves the ~hr of wasted work if merge2 silently
    # processes a half-built tmp_first/.
    required_after_merge1 = ("phasefilt.grd", "corr.grd", "mask.grd", f"{stem}.PRM")
    missing = [f for f in required_after_merge1 if not os.path.exists(f)]
    if missing:
        with open("tmp.filelist") as fh:
            line1 = fh.readline().strip()
        sys.exit(f"merge1 produced incomplete output — missing {missing} in merge/. "
                 f"stem='{stem}', tmp.filelist line1='{line1}'")
    for fn in ("merge_log", "tmp_phaselist"):
        if os.path.exists(fn):
            os.rename(fn, fn + "1")
    os.makedirs("tmp_first", exist_ok=True)
    for fn in ("phasefilt.grd", "corr.grd", "mask.grd"):
        os.rename(fn, f"tmp_first/{fn}")
    run(f"cp {stem}.PRM tmp_first/{stem}.PRM")
    # Second merge: tmp_first + F4 + F5
    with open("tmp2.filelist", "w") as f:
        f.write(f"./tmp_first/:{stem}.PRM:{stem}.PRM\n")
        for n in (4, 5):
            pth, m, s = pths[n]
            f.write(f"{pth}:{m}:{s}\n")
    run(f"merge_unwrap_geocode_tops tmp2.filelist {merge_conf} {det_stitch}")
    # Post-condition: merge2 MUST have produced the final phasefilt/corr/mask
    # at merge/ top level. Otherwise the snaphu/geocode step proceeds against
    # missing inputs and crashes deep in gmt with a non-obvious error.
    missing = [f for f in ("phasefilt.grd", "corr.grd", "mask.grd") if not os.path.exists(f)]
    if missing:
        with open("tmp2.filelist") as fh:
            listing = fh.read()
        sys.exit(f"merge2 produced incomplete output — missing {missing} in merge/. "
                 f"tmp2.filelist=\n{listing}")
    for fn in ("merge_log", "tmp_phaselist"):
        if os.path.exists(fn):
            os.rename(fn, fn + "2")
    if recompute_lut:
        _recompute_lut(stem, pths[1][0])
    return merge_conf


def _unwrap_and_geocode(abs_conf, merge_conf, mask_water):
    """Optionally unwrap (SNAPHU) and geocode the merged interferogram in merge/.

    Reads threshold_snaphu, threshold_geocode, region_cut, defomax and
    near_interp from ``merge_conf`` after the user's config has been copied
    over it. A "0" threshold skips the corresponding step.
    """
    # Copying the user's config here overwrites merge_conf (same basename),
    # which the merge passes had written with both thresholds forced to 0.
    # The values read below are therefore the user's own.
    run(f"cp {abs_conf} .")
    threshold_snaphu = _grep_field(merge_conf, "threshold_snaphu") or "0"
    threshold_geocode = _grep_field(merge_conf, "threshold_geocode") or "0"
    region_cut = _grep_field(merge_conf, "region_cut") or ""
    defomax = _grep_field(merge_conf, "defomax") or "0"
    near_interp = _grep_field(merge_conf, "near_interp") or "0"
    if not region_cut:
        # `gmt grdinfo -I-` prints the bounds as "-Rw/e/s/n". Keep everything
        # after "-R" rather than a fixed-width slice, which would truncate
        # long or fractional bounds.
        out = subprocess.check_output(["gmt", "grdinfo", "phasefilt.grd", "-I-"], text=True).strip()
        region_cut = out[2:] if out.startswith("-R") else out
    if threshold_snaphu != "0":
        if mask_water == 1 and not os.path.exists("landmask_ra.grd"):
            run(f"landmask {region_cut}")
        print("SNAPHU - START")
        if near_interp == "1":
            run(f"snaphu_interp {threshold_snaphu} {defomax} {region_cut}")
        else:
            run(f"snaphu {threshold_snaphu} {defomax} {region_cut}")
        print("SNAPHU - END")
    else:
        print("SKIP UNWRAP PHASE")
    if threshold_geocode != "0":
        print("GEOCODE - START")
        run("proj_ra2ll trans.dat phasefilt.grd phasefilt_ll.grd")
        run("proj_ra2ll trans.dat corr.grd corr_ll.grd")
        run("gmt makecpt -Crainbow -T-3.15/3.15/0.05 -Z > phase.cpt")
        # grdinfo -C columns: name w e s n z_min z_max ...
        out = subprocess.check_output(["gmt", "grdinfo", "-C", "corr.grd"], text=True)
        bt = out.split()[6]
        run(f"gmt makecpt -Cgray -T0/{bt}/0.05 -Z -M --COLOR_NAN=red > corr.cpt")
        run("grd2kml phasefilt_ll phase.cpt")
        run("grd2kml corr_ll corr.cpt")
        if os.path.exists("unwrap.grd"):
            run("gmt grdmath unwrap.grd mask.grd MUL = unwrap_mask.grd")
            run("proj_ra2ll trans.dat unwrap.grd unwrap_ll.grd")
            run("proj_ra2ll trans.dat unwrap_mask.grd unwrap_mask_ll.grd")
            out = subprocess.check_output(["gmt", "grdinfo", "-C", "unwrap.grd"], text=True)
            parts = out.split()
            bl, bt2 = parts[5], parts[6]
            run(f"gmt makecpt -T{bl}/{bt2}/0.5 -Z > unwrap.cpt")
            run("grd2kml unwrap_mask_ll unwrap.cpt")
            run("grd2kml unwrap_ll unwrap.cpt")
        print("GEOCODE END")
    else:
        print("SKIP GEOCODE")


def p2p_alos2_scan_frame():
    """Command-line entry point; see the module docstring for usage.

    Reads proc_stage, correct_iono, det_stitch and mask_water from the
    config. When proc_stage < 5, subswaths F1..F5 are processed (and at
    proc_stage < 2 they are also preprocessed and run through samp_slc). The
    merge step then always runs in ./merge, followed by optional SNAPHU and
    optional geocoding. Exits if correct_iono != 0 (not ported yet).
    """
    if len(sys.argv) != 5:
        sys.exit("Usage: p2p_ALOS2_SCAN_Frame <master_stem> <aligned_stem> <config.txt> <parallel>")
    master_arg, aligned_arg, conf, parallel = sys.argv[1:5]
    parallel = int(parallel)
    # csh: master = substr($1, 8) — strips "IMG-HH-".
    master = master_arg[7:]
    aligned = aligned_arg[7:]
    stage = int(_grep_field(conf, "proc_stage") or 1)
    iono = int(_grep_field(conf, "correct_iono") or 0)
    det_stitch = _grep_field(conf, "det_stitch") or "0"
    mask_water = int(_grep_field(conf, "mask_water") or 0)
    print(f"p2p_ALOS2_SCAN_Frame: master={master} aligned={aligned} stage={stage} parallel={parallel}")
    if iono != 0:
        sys.exit("p2p_ALOS2_SCAN_Frame: iono!=0 not ported yet.")

    case_dir = os.getcwd()
    # Absolute so it stays valid after chdir into F{n}/ and merge/.
    abs_conf = os.path.abspath(conf)

    if stage < 5:
        _run_subswaths(master, aligned, master_arg, aligned_arg,
                       abs_conf, case_dir, stage, parallel)

    os.makedirs("merge", exist_ok=True)
    os.chdir("merge")
    merge_conf = _merge_subswaths(abs_conf, det_stitch)
    _unwrap_and_geocode(abs_conf, merge_conf, mask_water)
    os.chdir(case_dir)


if __name__ == "__main__":
    # Script entry: all arguments are read from sys.argv by the entry point.
    p2p_alos2_scan_frame()
