#! /usr/bin/env python3
"""align_batch — geometric alignment of a stack of images to a common master.

Python port of csh align_batch.csh (K. Wang 2018). Aligns a set of images
to a common master using precise orbits + a DEM. Run from the top-level
case directory; creates / populates SLC/.

Algorithm (per aligned image):
  1. Build master.PRM with calc_dop_orb output appended, set fdd1=fddd1=0.
  2. Downsample DEM → flt.grd → topo.llt.
  3. SAT_llt2rat master.PRM → master.ratll.
  4. (RAW only) sarp master to make the master SLC.
  5. For each aligned image:
     - calc_dop_orb with master's earth_radius, then update_PRM to copy the
       master's earth_radius and fd1 into aligned.PRM.
     - (RAW only) sarp aligned.
     - SAT_llt2rat aligned → aligned.ratll.
     - paste master.ratll aligned.ratll → offset.dat.
     - fitoffset 3 3 → resamp.
     - Optional secondary xcorr + fitoffset 1 1 pass when secondary_align=1.

Usage:  align_batch RAW|SLC secondary_align(0|1) data.in

data.in format:
  line 1: master image basename
  line 2+: aligned image basenames
"""
import os
import shutil
import sys

from gmtsar_lib import run, grep_value


def _replace_prm_field(prm, key, value):
    """Append a 'key = value' line to the PRM file (matches `echo X >> P`)."""
    with open(prm, "a") as f:
        f.write(f"{key:22s} = {value}\n")


def _grep_field3_pipe(prm, key, last=True):
    """grep KEY P | (tail -1) | awk '{print $3}'. last=False → first match."""
    matches = []
    with open(prm) as f:
        for line in f:
            if key in line:
                parts = line.split()
                if len(parts) >= 3:
                    matches.append(parts[2])
    return matches[-1] if (last and matches) else (matches[0] if matches else "")


def _rebuild_prm_with_orbit(name, earth_radius):
    """Rebuild NAME.PRM as NAME.PRM0 + calc_dop_orb output, with fdd1=fddd1=0 appended.

    Moves NAME.PRM aside to NAME.PRM0, runs ``calc_dop_orb`` with the given
    earth_radius argument into NAME.log, concatenates both into a fresh
    NAME.PRM, appends zeroed fdd1/fddd1 records and removes NAME.log.
    """
    run(f"mv {name}.PRM {name}.PRM0")
    run(f"calc_dop_orb {name}.PRM0 {name}.log {earth_radius}")
    run(f"cat {name}.PRM0 {name}.log > {name}.PRM")
    _replace_prm_field(f"{name}.PRM", "fdd1", "0")
    _replace_prm_field(f"{name}.PRM", "fddd1", "0")
    run(f"rm -f {name}.log")


def _write_geometric_offsets(rmax, amax):
    """Turn master.ratll + aligned.ratll into offset.dat, r.xyz and a.xyz.

    offset.dat keeps only points whose master range ($1) lies in (0, rmax)
    and whose master azimuth ($3) lies in (0, amax).
    """
    run("paste master.ratll aligned.ratll | "
        "awk '{printf(\"%.6f %.6f %.6f %.6f %d\\n\", $1, $6 - $1, $2, $7 - $2, \"100\")}' "
        "> tmp.dat")
    run(f"awk '{{if($1 > 0 && $1 < {rmax} && $3 > 0 && $3 < {amax}) print $0 }}' "
        "< tmp.dat > offset.dat")
    run("awk '{ printf(\"%.6f %.6f %.6f \\n\",$1,$3,$2) }' < offset.dat > r.xyz")
    run("awk '{ printf(\"%.6f %.6f %.6f \\n\",$1,$3,$4) }' < offset.dat > a.xyz")


def _resample_into_place(master, prm, aligned):
    """Resample ALIGNED.SLC onto the master grid using *prm*.

    The previous SLC is kept as ALIGNED.SLC_old and the resampled image takes
    the name ALIGNED.SLC; ALIGNED.PRMresamp is left for the caller.
    """
    run(f"resamp {master}.PRM {prm} {aligned}.PRMresamp {aligned}.SLCresamp 4")
    run(f"mv {aligned}.SLC {aligned}.SLC_old")
    run(f"mv {aligned}.SLCresamp {aligned}.SLC")


def _refine_with_xcorr(master, aligned):
    """Optional second pass: xcorr the resampled image against the master,
    fit with ``fitoffset 1 1`` and resample again."""
    run(f"cp {aligned}.PRM tmp.PRM")
    # Zero the first-pass shifts/stretches in the copy only (presumably so
    # xcorr measures the residual offset of the already-resampled SLC).
    for key in ("rshift", "ashift"):
        run(f"update_PRM tmp.PRM {key} 0")
    for key in ("sub_int_r", "sub_int_a", "stretch_r",
                "stretch_a", "a_stretch_r", "a_stretch_a"):
        run(f"update_PRM tmp.PRM {key} 0.0")
    run(f"xcorr {master}.PRM tmp.PRM -xsearch 128 -ysearch 256 -nx 15 -ny 30")
    run("fitoffset 1 1 freq_xcorr.dat 20 >> tmp.PRM")
    _resample_into_place(master, "tmp.PRM", aligned)
    # Legacy keeps the old {aligned}.PRM intact for phasediff compatibility.


def _align_stack(data_type, secondary, master, aligned_list):
    """Run the full alignment pipeline. Must be called with SLC/ as the cwd."""
    # Stage raw inputs into SLC/
    run("ln -sf ../raw/*.PRM .")
    run("ln -sf ../raw/*.LED .")
    run(f"ln -sf ../raw/*.{'raw' if data_type == 'RAW' else 'SLC'} .")

    if not os.path.isfile("../topo/dem.grd"):
        sys.exit("DEM file not found at ../topo/dem.grd")

    print("Downsample the DEM data")
    run("gmt grdfilter ../topo/dem.grd -D3 -Fg2 -I12s -Ni -Gflt.grd")
    run("gmt grd2xyz --FORMAT_FLOAT_OUT=%lf flt.grd -s > topo.llt")

    # Build master PRM with calc_dop_orb appended
    print(f"calculating the SAT height for master {master}")
    _rebuild_prm_with_orbit(master, 0)

    master_prm = f"{master}.PRM"
    earth_radius = _grep_field3_pipe(master_prm, "earth_radius", last=True)
    fd1 = _grep_field3_pipe(master_prm, "fd1", last=True)
    # An empty value would silently produce malformed shell commands below.
    if not earth_radius or not fd1:
        sys.exit(f"align_batch: earth_radius/fd1 not found in {master_prm}")
    run(f"update_PRM {master_prm} earth_radius {earth_radius}")
    run(f"update_PRM {master_prm} fd1 {fd1}")

    run(f"SAT_llt2rat {master_prm} 1 < topo.llt > master.ratll")
    rmax = grep_value(master_prm, "num_rng_bins", 3)
    amax = grep_value(master_prm, "num_lines", 3)

    if data_type == "RAW":
        print("Focusing the master - START")
        run(f"sarp {master_prm}")
        print("Focusing the master - END")

    # Per-aligned loop
    n = len(aligned_list)
    for i, aligned in enumerate(aligned_list, start=1):
        print(f"working on {aligned}  [ {i} / {n} ]")

        _rebuild_prm_with_orbit(aligned, earth_radius)
        # Force the aligned image onto the master's earth_radius and fd1.
        run(f"update_PRM {aligned}.PRM earth_radius {earth_radius}")
        run(f"update_PRM {aligned}.PRM fd1 {fd1}")

        if data_type == "RAW":
            print("Focusing the aligned - START")
            run(f"sarp {aligned}.PRM")
            print("Focusing the aligned - END")

        run(f"SAT_llt2rat {aligned}.PRM 1 < topo.llt > aligned.ratll")
        _write_geometric_offsets(rmax, amax)

        run(f"fitoffset 3 3 offset.dat >> {aligned}.PRM")
        _resample_into_place(master, f"{aligned}.PRM", aligned)
        run(f"cp {aligned}.PRMresamp {aligned}.PRM")

        if secondary > 0:
            _refine_with_xcorr(master, aligned)

    # Cleanup
    run("rm -f tmp* *old *resamp *.raw *.PRM0")


def align_batch():
    """Entry point: align every image in data.in to its master (see module docstring).

    Reads ``sys.argv`` = [prog, RAW|SLC, secondary_align, data.in]. Run from
    the top-level case directory: the pipeline works inside SLC/ and expects
    ../raw/ and ../topo/dem.grd relative to it. Exits via ``sys.exit`` with a
    message on bad arguments or missing inputs.
    """
    if len(sys.argv) != 4:
        sys.exit(
            "Usage: align_batch RAW|SLC secondary_align(0|1) data.in\n"
            "  data.in: line 1 = master, line 2+ = aligned basenames."
        )
    data_type, secondary_str, data_in = sys.argv[1], sys.argv[2], sys.argv[3]
    if data_type not in ("RAW", "SLC"):
        sys.exit("data_type must be RAW or SLC")
    try:
        secondary = int(secondary_str)
    except ValueError:
        sys.exit("secondary_align must be an integer (0 or 1)")

    # Read and validate data.in BEFORE `cleanup SLC`, so a typo in the path
    # or a too-short list does not wipe an existing SLC/ directory.
    try:
        with open(data_in) as f:
            all_names = [ln.strip() for ln in f if ln.strip()]
    except OSError as err:
        sys.exit(f"align_batch: cannot read {data_in}: {err}")
    if len(all_names) < 2:
        sys.exit("align_batch: need at least 1 master + 1 aligned in data.in")
    master, aligned_list = all_names[0], all_names[1:]

    print("START ALIGN A STACK OF IMAGES\n")

    os.makedirs("SLC", exist_ok=True)
    run("cleanup SLC")

    os.chdir("SLC")
    try:
        _align_stack(data_type, secondary, master, aligned_list)
    finally:
        # Restore the caller's working directory even if a step fails.
        os.chdir("..")
    print("\nEND ALIGN A STACK OF IMAGES")


if __name__ == "__main__":
    # Script entry point: align_batch() reads its arguments from sys.argv.
    align_batch()
