#! /usr/bin/env python3
"""stack_coherence_mask — average a stack of correlation grids into mask_def.grd.

Python port of the csh script stack_coherence_mask.csh. Reads a list of
correlation .grd files, computes their per-pixel mean, then thresholds it to
produce mask_def.grd (1 where the mean is at or above the threshold, NaN
elsewhere).

Usage:  stack_coherence_mask grid_list average_corr_threshold
"""
import math
import shlex
import sys

from gmtsar_lib import run


def stack_coherence_mask():
    """Average a stack of correlation grids and write a binary mask_def.grd.

    Arguments are read from ``sys.argv``:
        grid_list: text file listing one correlation .grd path per line;
            blank lines are ignored.
        average_corr_threshold: number compared against the per-pixel mean
            correlation.

    Writes ``./mask_def.grd`` in the current directory: 1 where the mean is
    at or above the threshold, NaN elsewhere. The grid work is delegated to
    ``cp`` and ``gmt grdmath`` command strings executed through ``run``.

    Raises:
        SystemExit: on a wrong argument count, a non-numeric or non-finite
            threshold, an unreadable list file, or an empty list.
    """
    if len(sys.argv) != 3:
        sys.exit(
            "Usage: stack_coherence_mask grid_list average_corr_threshold\n"
            "  All listed corr grids must have consistent dimensions."
        )
    grid_list, threshold = sys.argv[1], sys.argv[2]

    # Validate the threshold before doing any work. It is spliced into a
    # command string, so an unchecked value could fail deep inside GMT or
    # smuggle extra shell text into the command.
    try:
        threshold_value = float(threshold)
    except ValueError:
        sys.exit(
            f"stack_coherence_mask: threshold must be a number, got {threshold!r}"
        )
    if not math.isfinite(threshold_value):
        sys.exit(
            f"stack_coherence_mask: threshold must be finite, got {threshold!r}"
        )
    # repr() of a finite float is a plain numeric literal (e.g. "0.3") with
    # no whitespace or shell metacharacters.
    threshold_arg = repr(threshold_value)

    try:
        with open(grid_list) as f:
            files = [ln.strip() for ln in f if ln.strip()]
    except OSError as err:
        sys.exit(f"stack_coherence_mask: cannot read {grid_list}: {err}")
    if not files:
        sys.exit("stack_coherence_mask: empty input list")

    # Paths come from a user-supplied list and are interpolated into command
    # strings; quote them in case ``run`` hands the string to a shell
    # (NOTE(review): confirm in gmtsar_lib). shlex.quote returns names made of
    # safe characters unchanged, so ordinary paths produce the same commands.
    grids = [shlex.quote(p) for p in files]

    # Initialize mask_def.grd to zeros with the same dimensions as the first
    # grid (NaN pixels of that grid stay NaN, since NaN * 0 is NaN).
    run(f"cp {grids[0]} ./mask_def.grd")
    run("gmt grdmath mask_def.grd 0 MUL = mask_def.grd")

    # Sum every grid, including the first; the zero-fill above means it is
    # counted exactly once. This matches the legacy csh loop.
    for grid in grids:
        run(f"gmt grdmath {grid} mask_def.grd ADD = mask_def.grd")

    # Divide by the count to get the mean, then threshold: GE yields 1/0 and
    # "0 NAN" replaces the 0s with NaN.
    run(f"gmt grdmath mask_def.grd {len(files)} DIV = mask_def.grd")
    run(f"gmt grdmath mask_def.grd {threshold_arg} GE 0 NAN = mask_def.grd")


if __name__ == "__main__":
    # Script entry point; the function reads its own arguments from sys.argv.
    stack_coherence_mask()
