#! /usr/bin/env python3
"""merge_unwrap_geocode_tops — multi-burst S1 TOPS merge + unwrap + geocode.

Python port of csh merge_unwrap_geocode_tops.csh (X. Xu 2016). Merges
2- or 3-subswath TOPS interferograms (phasefilt, corr, mask), optionally
auto-computes stitching column positions, then unwraps and geocodes.

This eliminates the LAST csh shell-out from any Python entry point
(p2p_S1_TOPS_Frame:119).

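Usage:  merge_unwrap_geocode_tops inputfile config_file [det_stitch]

det_stitch=1 estimates the subswath stitching columns from the valid
(non-NaN) extent of each subswath's phasefilt.grd. Any other value
(default 0) runs merge_swath without explicit stitching columns. The same
happens when the estimated first column is not greater than 5.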

inputfile format:
  Swath1_Path:Swath1_master.PRM:Swath1_repeat.PRM
  Swath2_Path:Swath2_master.PRM:Swath2_repeat.PRM
  Swath3_Path:Swath3_master.PRM:Swath3_repeat.PRM
"""
import os
import subprocess
import sys
from gmtsar_lib import run, grep_value, check_file_report


def _capture(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE,
                          check=False).stdout.decode("utf-8").strip()


def _grep_field(prm, key, field=3, last=True):
    """grep <key> <prm> [ | tail -1] | awk '{print $field}'."""
    matches = []
    with open(prm) as f:
        for line in f:
            if key in line:
                parts = line.split()
                if len(parts) >= field:
                    matches.append(parts[field - 1])
    return (matches[-1] if matches else "") if last else (matches[0] if matches else "")


def _get_config(path, key, default=""):
    """Look up *key* in the config file *path*.

    Returns the third whitespace token of the first line that mentions *key*,
    presumably a ``key = value`` layout. Returns *default* when that token is
    missing or empty.
    """
    value = _grep_field(path, key, 3, last=False)
    return value if value else default


def _build_swath_tmp_prm(pth, prm, prm2):
    """Create ``<pth>/tmp.PRM`` from the master PRM *prm*, patched from the repeat PRM *prm2*.

    * ``rshift`` is set to the value on the last ``rshift`` line of *prm2*.
    * ``first_sample`` is raised to *prm2*'s value when that is numerically
      larger than *prm*'s.

    *prm* and *prm2* are file names relative to *pth*. The shell commands run
    with *pth* as the working directory, and the caller's working directory is
    restored even if one of them raises.
    """
    rshift = _grep_field(os.path.join(pth, prm2), "rshift", 3, last=True)
    first_sample_master = _grep_field(os.path.join(pth, prm), "first_sample", 3, last=False)
    first_sample_repeat = _grep_field(os.path.join(pth, prm2), "first_sample", 3, last=False)

    cwd_save = os.getcwd()
    os.chdir(pth)
    try:
        run(f"cp {prm} tmp.PRM")
        if (first_sample_repeat and first_sample_master
                and float(first_sample_repeat) > float(first_sample_master)):
            run(f"update_PRM tmp.PRM first_sample {first_sample_repeat}")
        run(f"update_PRM tmp.PRM rshift {rshift}")
    finally:
        # Without this, a failing command would leave the whole script running
        # inside the swath directory, breaking every later relative path.
        os.chdir(cwd_save)


def _grdinfo_field(grd, field, flags="-C"):
    """Return whitespace field number *field* (1-based) of ``gmt grdinfo <flags> <grd>``.

    Returns "" when the output has fewer than *field* tokens.
    """
    tokens = _capture(f"gmt grdinfo {flags} {grd}").split()
    if len(tokens) < field:
        return ""
    return tokens[field - 1]


def _det_stitch_2(filelist_lines):
    """Compute n1 for the 2-subswath stitch position. Returns (n1, n2='')."""
    pth0 = filelist_lines[0].split(":")[0]
    near1 = _grep_field(f"{pth0}/tmp.PRM", "near", 3, last=False)
    rng1  = _grep_field(f"{pth0}/tmp.PRM", "num_rng_bins", 3, last=False)
    fs    = _grep_field(f"{pth0}/tmp.PRM", "rng_samp_rate", 3, last=True)
    run(f"gmt grdcut {pth0}/phasefilt.grd -Z+N -Gtmp.grd")
    xm1  = _grdinfo_field(f"{pth0}/phasefilt.grd", 3)
    xc1  = _grdinfo_field("tmp.grd", 3)
    incx = _grdinfo_field("tmp.grd", 8)
    n12 = int((float(xm1) - float(xc1)) / float(incx))

    pth1 = filelist_lines[-1].split(":")[0]
    near2 = _grep_field(f"{pth1}/tmp.PRM", "near", 3, last=False)
    run(f"gmt grdcut {pth1}/phasefilt.grd -Z+N -Gtmp.grd")
    x01  = _grdinfo_field("tmp.grd", 2)
    incx = _grdinfo_field("tmp.grd", 8)
    n21 = int(float(x01) / float(incx))
    ovl12 = int((float(rng1) - (float(near2) - float(near1)) / (299792458.0 / float(fs) / 2.0)) / float(incx))
    n1 = (ovl12 - n12 - n21) // 2 + n21
    run("rm -f tmp.grd")
    if n1 <= 0:
        print("WARNING: Stitching position estimated to be zero — check merged grids")
        return ("", "")
    return (str(n1), "0")


def _det_stitch_3(filelist_lines):
    """Compute (n1, n2) stitch positions for 3 subswaths."""
    pth0 = filelist_lines[0].split(":")[0]
    near1 = _grep_field(f"{pth0}/tmp.PRM", "near", 3, last=False)
    rng1  = _grep_field(f"{pth0}/tmp.PRM", "num_rng_bins", 3, last=False)
    fs    = _grep_field(f"{pth0}/tmp.PRM", "rng_samp_rate", 3, last=True)
    run(f"gmt grdcut {pth0}/phasefilt.grd -Z+N -Gtmp.grd")
    xm1  = _grdinfo_field(f"{pth0}/phasefilt.grd", 3)
    xc1  = _grdinfo_field("tmp.grd", 3)
    incx = _grdinfo_field("tmp.grd", 8)
    n12 = int((float(xm1) - float(xc1)) / float(incx))

    pth1 = filelist_lines[1].split(":")[0]
    near2 = _grep_field(f"{pth1}/tmp.PRM", "near", 3, last=False)
    rng2  = _grep_field(f"{pth1}/tmp.PRM", "num_rng_bins", 3, last=False)
    run(f"gmt grdcut {pth1}/phasefilt.grd -Z+N -Gtmp.grd")
    x02  = _grdinfo_field("tmp.grd", 2)
    incx = _grdinfo_field("tmp.grd", 8)
    n21 = int(float(x02) / float(incx))
    ovl12 = int((float(rng1) - (float(near2) - float(near1)) / (299792458.0 / float(fs) / 2.0)) / float(incx))
    n1 = (ovl12 - n12 - n21) // 2 + n21
    xm2 = _grdinfo_field(f"{pth1}/phasefilt.grd", 3)
    xc2 = _grdinfo_field("tmp.grd", 3)
    n22 = int((float(xm2) - float(xc2)) / float(incx))

    pth2 = filelist_lines[2].split(":")[0]
    near3 = _grep_field(f"{pth2}/tmp.PRM", "near", 3, last=False)
    run(f"gmt grdcut {pth2}/phasefilt.grd -Z+N -Gtmp.grd")
    x03  = _grdinfo_field("tmp.grd", 2)
    incx = _grdinfo_field("tmp.grd", 8)
    n31 = int(float(x03) / float(incx))
    ovl23 = int((float(rng2) - (float(near3) - float(near2)) / (299792458.0 / float(fs) / 2.0)) / float(incx))
    n2 = (ovl23 - n22 - n31) // 2 + n31

    run("rm -f tmp.grd")
    if n2 == 0:
        print("WARNING: Stitching positions estimated to be zero — check merged grids")
        return ("", "")
    return (str(n1), str(n2))


def _write_merge_lists(filelist_lines):
    """Build each swath's tmp.PRM and append its grids to the three merge lists.

    Each record is ``SwathPath:master.PRM:repeat.PRM``. Path and file name are
    joined by plain concatenation, so SwathPath presumably ends with "/"
    (TODO confirm for all callers).
    """
    with open("tmp_phaselist", "a") as phase_fl, open("tmp_corrlist", "a") as corr_fl, \
            open("tmp_masklist", "a") as mask_fl:
        for line in filelist_lines:
            pth, prm, prm2 = line.split(":")
            _build_swath_tmp_prm(pth, prm, prm2)
            phase_fl.write(f"{pth}tmp.PRM:{pth}phasefilt.grd\n")
            corr_fl.write(f"{pth}tmp.PRM:{pth}corr.grd\n")
            mask_fl.write(f"{pth}tmp.PRM:{pth}mask.grd\n")


def _merge_grids(stem, n1, n2):
    """Merge phasefilt/corr/mask with merge_swath, passing n1/n2 only when int(n1) > 5."""
    if n1 and int(n1) > 5:
        run(f"merge_swath tmp_phaselist phasefilt.grd {stem} {n1} {n2} > merge_log")
        # NOTE(review): corr/mask pass n1 n2 without the stem argument;
        # confirm this matches merge_swath's positional argument parsing.
        run(f"merge_swath tmp_corrlist  corr.grd  {n1} {n2} > merge_log_corr")
        run(f"merge_swath tmp_masklist  mask.grd  {n1} {n2} > merge_log_mask")
    else:
        run(f"merge_swath tmp_phaselist phasefilt.grd {stem} > merge_log")
        run("merge_swath tmp_corrlist  corr.grd      > merge_log_corr")
        run("merge_swath tmp_masklist  mask.grd      > merge_log_mask")


def _correct_ionosphere(config):
    """Subtract ph_iono_orig.grd from phasefilt.grd when correct_iono is set and iono_skip_est is not.

    The difference is rewrapped to [-pi, pi) with grdmath. The uncorrected
    grid is kept as phasefilt_orig.grd.
    """
    iono = int(_get_config(config, "correct_iono", "0") or 0)
    skip_iono = int(_get_config(config, "iono_skip_est", "0") or 0)
    if iono == 0 or skip_iono != 0:
        return
    if not check_file_report("ph_iono_orig.grd"):
        print("Need ph_iono_orig.grd to correct ionosphere ...")
        return
    print("Correcting ionosphere ...")
    run("gmt grdsample ph_iono_orig.grd -Rphasefilt.grd -Gtmp.grd")
    run("gmt grdmath phasefilt.grd tmp.grd SUB PI ADD 2 PI MUL MOD PI SUB = tmp2.grd")
    run("mv phasefilt.grd phasefilt_orig.grd")
    run("mv tmp2.grd phasefilt.grd")
    run("rm -f tmp.grd")


def _ensure_trans_dat(pth, stem):
    """Compute the geocoding LUT trans.dat from dem.grd with SAT_llt2rat, unless it exists.

    The LUT is projected with rshift=0 on the merged ``<stem>.PRM``. The
    original rshift is written back even if the projection fails, so the
    merged PRM is not left modified.
    """
    if check_file_report("trans.dat"):
        return
    led = _grep_field(f"{pth}{stem}.PRM", "led_file", 3, last=False)
    run(f"cp {pth}{led} .")
    print("Recomputing the projection LUT...")
    rshift = _grep_field(f"{stem}.PRM", "rshift", 3, last=True)
    run(f"update_PRM {stem}.PRM rshift 0")
    try:
        run(f"gmt grd2xyz --FORMAT_FLOAT_OUT=%lf dem.grd -s | SAT_llt2rat {stem}.PRM 1 -bod > trans.dat")
    finally:
        run(f"update_PRM {stem}.PRM rshift {rshift}")


def _unwrap(threshold_snaphu, defomax, near_interp, region_cut, mask_water, switch_land):
    """Run snaphu unless threshold_snaphu is 0, building landmask_ra.grd first if needed."""
    if threshold_snaphu == 0:
        print("\nSKIP UNWRAP PHASE")
        return
    if (mask_water == 1 or switch_land == 1) and not check_file_report("landmask_ra.grd"):
        run(f"landmask {region_cut}")
    print(f"\nSNAPHU - START (threshold {threshold_snaphu})")
    interp_flag = 1 if near_interp == 1 else 0
    run(f"snaphu {threshold_snaphu} {defomax} {interp_flag} {region_cut}")
    print("SNAPHU - END")


def _geocode(threshold_geocode):
    """Project merged products to lon/lat with proj_ra2ll and write KML previews.

    mask2.grd is mask.grd with pixels where corr < threshold_geocode set to
    NaN (``GE`` gives 0/1, ``0 NAN`` turns the 0s into NaN). Unwrapped and LOS
    products are produced only when unwrap.grd exists.
    """
    print("\nGEOCODE - START")
    run("proj_ra2ll trans.dat phasefilt.grd phasefilt_ll.grd")
    run(f"gmt grdmath corr.grd {threshold_geocode} GE 0 NAN mask.grd MUL = mask2.grd")
    run("gmt grdmath phasefilt.grd mask2.grd MUL = phasefilt_mask.grd")
    run("proj_ra2ll trans.dat phasefilt_mask.grd phasefilt_mask_ll.grd")
    run("proj_ra2ll trans.dat corr.grd corr_ll.grd")
    run("gmt makecpt -Crainbow -T-3.15/3.15/0.05 -Z > phase.cpt")
    BT = _grdinfo_field("corr.grd", 7)  # grdinfo -C column 7: z_max
    run(f"gmt makecpt -Cgray -T0/{BT}/0.05 -Z -M --COLOR_NAN=red > corr.cpt")
    run("grd2kml phasefilt_ll phase.cpt")
    run("grd2kml phasefilt_mask_ll phase.cpt")
    run("grd2kml corr_ll corr.cpt")

    if check_file_report("unwrap.grd"):
        run("gmt grdmath unwrap.grd mask2.grd MUL = unwrap_mask.grd")
        # NOTE(review): takes the first wavelength found in any *.PRM in cwd.
        wavel = _capture("grep wavelength *.PRM | awk '{print $3}' | head -1")
        # -79.58 ~ -1000/(4*pi); presumably converts phase*wavelength to LOS in mm.
        run(f"gmt grdmath unwrap_mask.grd {wavel} MUL -79.58 MUL = los.grd")
        run("proj_ra2ll trans.dat unwrap.grd unwrap_ll.grd")
        run("proj_ra2ll trans.dat unwrap_mask.grd unwrap_mask_ll.grd")
        run("proj_ra2ll trans.dat los.grd los_ll.grd")
        BT = _grdinfo_field("unwrap.grd", 7)
        BL = _grdinfo_field("unwrap.grd", 6)  # column 6: z_min
        run(f"gmt makecpt -T{BL}/{BT}/0.5 -Z > unwrap.cpt")
        run("grd2kml unwrap_mask_ll unwrap.cpt")
        run("grd2kml unwrap_ll unwrap.cpt")
        BT = _grdinfo_field("los.grd", 7)
        BL = _grdinfo_field("los.grd", 6)
        run(f"gmt makecpt -T{BL}/{BT}/2 -Z > los.cpt")
        run("grd2kml los_ll los.cpt")
    print("GEOCODE END")


def merge_unwrap_geocode_tops():
    """Command-line entry point: merge TOPS subswaths, then unwrap and geocode.

    Arguments come from sys.argv as ``inputfile config_file [det_stitch]``.
    inputfile holds one ``SwathPath:master.PRM:repeat.PRM`` record per
    subswath; lines without ":" are ignored. With det_stitch == 1, stitching
    columns are estimated for 2 or 3 subswaths.

    Exits via sys.exit on bad usage, a missing dem.grd, an input file without
    records, or an unsupported record count with det_stitch.
    """
    if len(sys.argv) not in (3, 4):
        sys.exit(
            "Usage: merge_unwrap_geocode_tops inputfile config_file [det_stitch]\n"
            "  inputfile lines: SwathPath:master.PRM:repeat.PRM (2 or 3 swaths)\n"
            "  Each path must contain phasefilt.grd, corr.grd, mask.grd."
        )
    inputfile, config = sys.argv[1], sys.argv[2]
    det_stitch = int(sys.argv[3]) if len(sys.argv) == 4 else 0

    # The merge lists are appended to, so start from scratch.
    for name in ("tmp_phaselist", "tmp_corrlist", "tmp_masklist"):
        if os.path.isfile(name):
            os.remove(name)

    if not check_file_report("dem.grd"):
        sys.exit("Please link dem.grd to current folder")

    with open(inputfile) as fh:
        filelist_lines = [ln.strip() for ln in fh if ":" in ln]
    if not filelist_lines:
        # filelist_lines[0] below needs at least one record.
        sys.exit(f"No SwathPath:master.PRM:repeat.PRM records found in {inputfile}")

    _write_merge_lists(filelist_lines)

    # The first record supplies the path holding the LED file and the stem
    # passed to merge_swath (the merged PRM is later read as <stem>.PRM).
    head = filelist_lines[0].split(":")
    pth, stem = head[0], head[1].rsplit(".", 1)[0]

    print("\nMerging START")
    n1, n2 = "", ""
    if det_stitch == 1:
        nl = len(filelist_lines)
        if nl == 2:
            n1, n2 = _det_stitch_2(filelist_lines)
        elif nl == 3:
            n1, n2 = _det_stitch_3(filelist_lines)
        else:
            sys.exit(f"Incorrect number of records in input filelist: {nl}")
        print(f"Stitching positions set to {n1} {n2}")
    _merge_grids(stem, n1, n2)
    print("Merging END\n")

    _correct_ionosphere(config)
    _ensure_trans_dat(pth, stem)

    threshold_snaphu  = float(_get_config(config, "threshold_snaphu",  "0") or 0)
    threshold_geocode = float(_get_config(config, "threshold_geocode", "0") or 0)
    region_cut        = _get_config(config, "region_cut", "")
    switch_land       = int(_get_config(config, "switch_land",  "0") or 0)
    # Forwarded verbatim rather than int()-parsed: a fractional value such
    # as 1.2 previously raised ValueError here.
    defomax           = _get_config(config, "defomax", "0")
    near_interp       = int(_get_config(config, "near_interp",  "0") or 0)
    mask_water        = int(_get_config(config, "mask_water",   "0") or 0)

    if not region_cut:
        # Presumably strips the leading "-R" from grdinfo's -I- region string.
        region_cut = _capture("gmt grdinfo phasefilt.grd -I- | cut -c3-20")

    _unwrap(threshold_snaphu, defomax, near_interp, region_cut, mask_water, switch_land)
    if threshold_geocode != 0:
        _geocode(threshold_geocode)

    run("rm -f tmp_phaselist tmp_corrlist tmp_masklist *.eps *.bb")


if __name__ == "__main__":
    merge_unwrap_geocode_tops()
