#! /usr/bin/env python3
"""intf_batch — form a stack of interferograms from a pairs list.

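Python port of the csh script intf_batch.csh (X. Tong + D. Sandwell, 2010;
ENVI_SLC support by A. Hogrelius, 2017). Reads intf.in (one "ref:rep" pair per
line) and runs the single-pair workflow (intf + filter + snaphu + geocode) for
each pair, sharing the topo_ra built once in topo/ across all pairs.
Run it from the processing root that contains the raw/, SLC/ and topo/
directories; per-pair results are written to intf/<ref_id>_<rep_id>/.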

Usage:  intf_batch SAT intf.in batch.config

Supported SATs: ALOS, ENVI, ENVI_SLC, ERS, TSX
"""
import os
import subprocess
import sys

from gmtsar_lib import run, grep_value, check_file_report


_SUPPORTED_SATS = {"ALOS", "ENVI", "ERS", "ENVI_SLC", "TSX"}


def _get_config(path, key, default=""):
    """grep <key> <path> | awk '{print $3}'."""
    with open(path) as f:
        for line in f:
            if key in line:
                parts = line.split()
                if len(parts) >= 3:
                    return parts[2]
    return default


def _topo_stage(master, SAT, topo_phase, shift_topo):
    """Stage 1: make topo_ra (+ optional topo_shift)."""
    run("cleanup topo")
    if topo_phase == 0:
        print("NO TOPOPHASE IS SUBSTRACTED")
        return
    if topo_phase != 1:
        sys.exit(f"Wrong parameter: topo_phase {topo_phase}")

    print("DEM2TOPOPHASE - START")
    os.chdir("topo")
    run(f"cp ../SLC/{master}.PRM master.PRM")
    run(f"ln -sf ../raw/{master}.LED .")
    if not check_file_report("dem.grd"):
        sys.exit("no DEM file found: dem.grd")
    run("dem2topo_ra master.PRM dem.grd")
    os.chdir("..")
    print("DEM2TOPOPHASE - END")

    if shift_topo == 0:
        print("NO TOPOPHASE SHIFT")
        return
    if shift_topo != 1:
        sys.exit(f"Wrong parameter: shift_topo {shift_topo}")

    print("OFFSET_TOPO - START")
    os.chdir("SLC")
    if SAT in ("ALOS", "TSX"):
        rng_samp_rate_raw = grep_value(f"{master}.PRM", "rng_samp_rate", 3)
        try:
            rng_samp_rate = int(float(rng_samp_rate_raw))
        except ValueError:
            sys.exit("Undefined rng_samp_rate in the master PRM file")
        if rng_samp_rate > 25_000_000:
            print("processing ALOS FBS data"); rng = 2
        else:
            print("processing ALOS FBD data"); rng = 1
        run(f"slc2amp {master}.PRM {rng} amp-{master}.grd")
    elif SAT in ("ERS", "ENVI", "ENVI_SLC"):
        run(f"slc2amp {master}.PRM 1 amp-{master}.grd")
    os.chdir("../topo")
    run(f"ln -sf ../SLC/amp-{master}.grd .")
    run(f"offset_topo amp-{master}.grd topo_ra.grd 0 0 7 topo_shift.grd")
    os.chdir("..")
    print("OFFSET_TOPO - END")


def _intf_one_pair(ref, rep, filter_wavelen, dec, topo_phase, shift_topo,
                   threshold_snaphu, defomax, region_cut, near_interp,
                   mask_water, switch_land, threshold_geocode):
    """Per-pair work: intf + filter + snaphu + geocode in intf/<ref_id>_<rep_id>/."""
    ref_id = int(float(grep_value(f"./SLC/{ref}.PRM", "SC_clock_start", 3)))
    rep_id = int(float(grep_value(f"./SLC/{rep}.PRM", "SC_clock_start", 3)))
    sub = f"{ref_id}_{rep_id}"

    print(f"\nINTF + FILTER - START (pair {sub})")
    os.chdir("intf")
    os.makedirs(sub, exist_ok=True)
    os.chdir(sub)
    run(f"ln -sf ../../raw/{ref}.LED .")
    run(f"ln -sf ../../raw/{rep}.LED .")
    run(f"ln -sf ../../SLC/{ref}.SLC .")
    run(f"ln -sf ../../SLC/{rep}.SLC .")
    run(f"cp ../../SLC/{ref}.PRM .")
    run(f"cp ../../SLC/{rep}.PRM .")

    if topo_phase == 1:
        topo_file = "topo_shift.grd" if shift_topo == 1 else "topo_ra.grd"
        run(f"ln -sf ../../topo/{topo_file} .")
        run(f"intf {ref}.PRM {rep}.PRM -topo {topo_file}")
    else:
        run(f"intf {ref}.PRM {rep}.PRM")
    run(f"filter {ref}.PRM {rep}.PRM {filter_wavelen} {dec}")
    print("INTF + FILTER - END")

    if not region_cut:
        region_cut = subprocess.run(
            "gmt grdinfo phase.grd -I- | cut -c3-20",
            shell=True, stdout=subprocess.PIPE,
        ).stdout.decode("utf-8", "replace").strip()

    if threshold_snaphu != 0:
        if mask_water == 1 or switch_land == 1:
            os.chdir("../../topo")
            if not check_file_report("landmask_ra.grd"):
                run(f"landmask {region_cut}")
            os.chdir(f"../intf/{sub}")
            run("ln -sf ../../topo/landmask_ra.grd .")
        print(f"\nSNAPHU - START (threshold {threshold_snaphu})")
        interp_flag = 1 if near_interp == 1 else 0
        run(f"snaphu {threshold_snaphu} {defomax} {interp_flag} {region_cut}")
        print("SNAPHU - END")
    else:
        print("SKIP UNWRAP PHASE")

    print("\nGEOCODE - START")
    run("rm -f raln.grd ralt.grd")
    if topo_phase != 1:
        sys.exit("topo_ra is needed to geocode")
    run("rm -f trans.dat")
    run("ln -sf ../../topo/trans.dat .")
    print(f"threshold_geocode: {threshold_geocode}")
    run(f"geocode {threshold_geocode}")
    print("GEOCODE - END")

    os.chdir("../..")


def intf_batch():
    if len(sys.argv) != 4:
        sys.exit(
            "Usage: intf_batch SAT intf.in batch.config\n"
            "  SAT: ALOS / ENVI / ENVI_SLC / ERS / TSX\n"
            "  intf.in: one 'ref:rep' per line"
        )
    SAT, intf_in, config = sys.argv[1], sys.argv[2], sys.argv[3]
    if SAT not in _SUPPORTED_SATS:
        sys.exit(f"SAT must be one of: {' '.join(sorted(_SUPPORTED_SATS))}")
    if not os.path.isfile(intf_in):
        sys.exit(f"no input file: {intf_in}")
    if not os.path.isfile(config):
        sys.exit(f"no config file: {config}")

    stage             = int(_get_config(config, "proc_stage", "1") or 1)
    master            = _get_config(config, "master_image")
    if not master:
        sys.exit("master image not set in config")
    filter_wavelen    = _get_config(config, "filter_wavelength", "200") or "200"
    dec               = _get_config(config, "dec_factor", "2") or "2"
    topo_phase        = int(_get_config(config, "topo_phase", "1") or 1)
    shift_topo        = int(_get_config(config, "shift_topo", "0") or 0)
    threshold_snaphu  = float(_get_config(config, "threshold_snaphu", "0") or 0)
    threshold_geocode = float(_get_config(config, "threshold_geocode", "0") or 0)
    region_cut        = _get_config(config, "region_cut", "")
    switch_land       = int(_get_config(config, "switch_land", "0") or 0)
    defomax           = int(_get_config(config, "defomax", "0") or 0)
    near_interp       = int(_get_config(config, "near_interp", "0") or 0)
    mask_water        = int(_get_config(config, "mask_water", "0") or 0)

    if stage <= 1:
        _topo_stage(master, SAT, topo_phase, shift_topo)

    if stage <= 2:
        print("\nSTART FORM A STACK OF INTERFEROGRAMS")
        os.makedirs("intf", exist_ok=True)
        with open(intf_in) as f:
            pairs = [ln.strip() for ln in f if ":" in ln]
        for line in pairs:
            ref, rep = line.split(":", 1)
            _intf_one_pair(
                ref, rep, filter_wavelen, dec, topo_phase, shift_topo,
                threshold_snaphu, defomax, region_cut, near_interp,
                mask_water, switch_land, threshold_geocode,
            )
        print("\nEND FORM A STACK OF INTERFEROGRAMS")


if __name__ == "__main__":
    intf_batch()
