#! /usr/bin/env python3
"""select_pairs — pick interferogram pairs within time/baseline thresholds.

Python port of csh select_pairs.csh (X. Xu 2016).

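Given baseline_table.dat (one whitespace-separated row per scene, in the
format produced by baseline_table), only columns 1, 3 and 5 are read:
  orb  ?  day  ?  bperp  ...
Column 1 is the scene identifier written to intf.in. `day` is a day number:
differences are compared against threshold_time, and the plot converts it to
years as 2014 + day/365.25. Rows with fewer than five fields, or with a
non-numeric day/bperp value, are skipped. All (i,j) pairs are enumerated
where i.day < j.day and (j.day - i.day) < dt and |i.bperp - j.bperp| < db.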

Writes intf.in (one "<ref>:<rep>" line per pair) and baseline.pdf showing
the network of selected pairs.

Usage:  select_pairs baseline_table.dat threshold_time threshold_baseline

Note: the legacy csh hardcoded a `+2014` year offset for the plot x-axis
(see line 28 of select_pairs.csh). It is preserved here for behavioral
parity, so the x-axis year labels will be off for non-2014-era data. A
future --year-offset argument could fix this without breaking
compatibility.
"""
import math
import os
import subprocess
import sys
from gmtsar_lib import run


def _read_baseline_table(path):
    """Parse *path* into a list of ``(name, day, baseline)`` tuples.

    Follows the legacy awk field positions ($1, $3, $5): column 1 is the
    scene name, column 3 the day number and column 5 the baseline. Lines
    with fewer than five fields, or whose third/fifth field is not a
    number, are skipped.
    """
    rows = []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if len(parts) < 5:
                continue
            try:
                day = float(parts[2])
                baseline = float(parts[4])
            except ValueError:
                continue
            rows.append((parts[0], day, baseline))
    return rows


def _pick_pairs(rows, dt, db):
    """Return every ``(ref_row, rep_row)`` pair within the thresholds.

    A pair qualifies when the repeat scene is strictly later than the
    reference, less than *dt* days after it, and the two baselines differ
    by less than *db*. The order matches the original nested loop: outer
    loop over reference rows, inner loop over repeat rows, both in file
    order.
    """
    pairs = []
    for ref in rows:
        for rep in rows:
            gap = rep[1] - ref[1]
            if ref[1] < rep[1] and gap < dt and abs(ref[2] - rep[2]) < db:
                pairs.append((ref, rep))
    return pairs


def _plot_region(points):
    """Return a GMT ``-R`` option enclosing *points* with legacy padding.

    *points* is a sequence of ``(x, y)``. x is padded by 0.5 and y by 50 on
    each side; an empty sequence yields the unit box ``-R0/1/0/1``.
    """
    if not points:
        return "-R0/1/0/1"
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    return (f"-R{min(xs) - 0.5}/{max(xs) + 0.5}/"
            f"{min(ys) - 50}/{max(ys) + 50}")


def _remove_quietly(*paths):
    """Delete each of *paths*, ignoring missing files (like ``rm -f``)."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


def _plot_network(rows, pairs, year_offset):
    """Draw scenes and selected pairs with GMT and convert to baseline.pdf.

    x is ``year_offset + day / 365.25``, y is the baseline. The scratch
    files text, text2 and edges.txt are written to the working directory
    and removed in a ``finally`` clause. baseline.ps is removed only after
    psconvert has been run, so an aborted run leaves it for inspection.
    """
    def to_year(day):
        return year_offset + day / 365.25

    points = [(to_year(day), baseline) for _, day, baseline in rows]
    try:
        # "text" carries labels for pstext; "text2" is the same x/y without
        # the label, used for the final marker layer.
        with open("text", "w") as labels, open("text2", "w") as dots:
            for (name, _, _), (x, y) in zip(rows, points):
                labels.write(f"{x} {y} {name}\n")
                dots.write(f"{x} {y}\n")

        run(f"gmt pstext text -JX8.8i/6.8i {_plot_region(points)} "
            f"-D0.2/0.2 -X1.5i -Y1i -K -N -F+f8,Helvetica+j5 > baseline.ps")

        if pairs:
            # One multi-segment file, so a single psxy call draws every edge.
            with open("edges.txt", "w") as f:
                for (_, t1, b1), (_, t2, b2) in pairs:
                    f.write(f"> -W0.5p\n{to_year(t1)} {b1}\n"
                            f"{to_year(t2)} {b2}\n")
            run("gmt psxy edges.txt -R -J -K -O >> baseline.ps")

        # Final layer: scene markers plus axes; no -K, so this closes the PS.
        run("gmt psxy text2 -Sp0.2c -G0 -R -JX "
            "-Ba0.5:\"year\":/a50g00f25:\"baseline (m)\":WSen -O >> baseline.ps")
        run("gmt psconvert baseline.ps -Tf -A")
        _remove_quietly("baseline.ps")
    finally:
        _remove_quietly("text", "text2", "edges.txt")


def select_pairs(year_offset=2014):
    """Command-line entry point: select pairs, write intf.in, plot baseline.pdf.

    Reads ``baseline_table.dat threshold_time threshold_baseline`` from
    ``sys.argv``. Exits with a usage message on a wrong argument count or
    non-numeric thresholds.

    Args:
        year_offset: Year added to ``day / 365.25`` for the plot x-axis.
            Defaults to 2014, the value hardcoded in the legacy csh script.
            It affects only baseline.pdf, never intf.in.
    """
    usage = ("Usage: select_pairs baseline_table.dat threshold_time threshold_baseline\n"
             "  outputs: intf.in (+ baseline.pdf network plot)")
    if len(sys.argv) != 4:
        sys.exit(usage)
    table_path = sys.argv[1]
    try:
        dt = float(sys.argv[2])
        db = float(sys.argv[3])
    except ValueError:
        sys.exit("select_pairs: thresholds must be numeric\n" + usage)

    rows = _read_baseline_table(table_path)
    pairs = _pick_pairs(rows, dt, db)

    # Write the pair list before any plotting, so a GMT failure cannot leave
    # the primary output missing. Mode "w" replaces any stale intf.in.
    with open("intf.in", "w") as f:
        f.writelines(f"{ref[0]}:{rep[0]}\n" for ref, rep in pairs)

    _plot_network(rows, pairs, year_offset)

    print(f"select_pairs: wrote {len(pairs)} pairs to intf.in")


if __name__ == "__main__":
    # Script entry point; select_pairs() reads its arguments from sys.argv.
    select_pairs()
