%---------------------------------------------------------------------------%
% vim: ft=mercury ts=4 sw=4 et
%---------------------------------------------------------------------------%
% File: grid.m
%
% Grid topology abstraction for square and hexagonal grids.
%
% 🤖 Generated with [Claude Code](https://claude.ai/code)
%
% Co-authored-by: Claude <noreply@anthropic.com>
%
%---------------------------------------------------------------------------%

:- module grid.
:- interface.

:- import_module list.
:- import_module options.

%---------------------------------------------------------------------------%

    % A cell position on the grid.
    %
    % For square grids: (x, y) coordinates, where x increases in the right
    % direction and y increases in the up direction.
    % For hex grids (flat-top): axial coordinates (q, r), where q increases
    % in the upper_right direction and r increases in the down direction.
    % Note that the vertical sense of the second coordinate is therefore
    % opposite between the two topologies.
    %
:- type cell
    --->    cell(int, int).

    % Directions connecting adjacent cells.
    %
    % up and down are valid for both topologies; each remaining direction
    % belongs to exactly one topology (see directions/1 for the full list
    % per topology). Using a direction with the wrong topology in adjacent/3
    % throws an exception.
    %
:- type direction
    --->    up
    ;       down
    ;       left          % square only
    ;       right         % square only
    ;       upper_left    % hex only
    ;       upper_right   % hex only
    ;       lower_left    % hex only
    ;       lower_right.  % hex only

    % An undirected edge between two adjacent cells.
    % Represented internally as a cell and canonical direction, so that
    % get_edge/3 yields the same value for an edge whichever of its two
    % cells it is approached from. Use edge_cells/4 to recover the cells.
    %
:- type edge.

%---------------------------------------------------------------------------%

    % Return all directions for the given topology: the four orthogonal
    % directions for square grids and the six hex directions for hex grids.
    %
:- func directions(topology) = list(direction).

    % Return the canonical directions for the given topology: exactly one
    % direction from each pair of opposite directions.
    % Square: up, right
    % Hex: up, upper_right, lower_right
    %
:- func canonical_directions(topology) = list(direction).

    % Return the opposite direction, so that opposite(opposite(D)) = D.
    % Defined for every direction, independent of topology.
    %
:- func opposite(direction) = direction.

    % Return the adjacent cell in the given direction.
    % Throws an exception if the direction is not one of
    % directions(Topology) (e.g. left on a hex grid).
    %
:- func adjacent(topology, cell, direction) = cell.

    % Return the edge between two adjacent cells.
    % Throws an exception if the cells are not adjacent (this includes the
    % case where both arguments are the same cell).
    %
:- func get_edge_between(topology, cell, cell) = edge.

    % Return the edge between a cell and its neighbour in the given direction.
    % The direction should be one of directions(Topology).
    %
:- func get_edge(topology, cell, direction) = edge.

    % Return the two cells on either side of an edge.
    % The order of the two cells follows the edge's canonical internal
    % representation, not the cell or direction the edge was built from.
    %
:- pred edge_cells(topology::in, edge::in, cell::out, cell::out) is det.

    % Return the coordinate delta for a direction, i.e. the offset from any
    % cell to its neighbour in that direction.
    % For square: delta for (X, Y) coordinates.
    % For hex: delta for (Q, R) coordinates.
    % Throws an exception if the direction is not one of
    % directions(Topology).
    %
:- pred direction_delta(topology::in, direction::in, int::out, int::out) is det.

%---------------------------------------------------------------------------%
%---------------------------------------------------------------------------%

:- implementation.

:- import_module int.
:- import_module require.

%---------------------------------------------------------------------------%

:- type edge
    --->    edge(cell, direction).
    % edge(Cell, Dir) denotes the edge between Cell and
    % adjacent(Topology, Cell, Dir). get_edge/3 stores an edge under a
    % canonical Dir, switching to the neighbouring cell (and the opposite
    % direction) when it is given a non-canonical direction.

%---------------------------------------------------------------------------%

directions(Topology) = Dirs :-
    % Every direction that adjacent/3 accepts for this topology.
    (
        Topology = square,
        Dirs = [up, down, left, right]
    ;
        Topology = hex,
        Dirs = [up, down, upper_left, upper_right, lower_left, lower_right]
    ).

canonical_directions(Topology) = CanonDirs :-
    % One representative from each pair of opposite directions.
    (
        Topology = square,
        CanonDirs = [up, right]
    ;
        Topology = hex,
        CanonDirs = [up, upper_right, lower_right]
    ).

%---------------------------------------------------------------------------%

opposite(Dir) = Opposite :-
    % Each direction is paired with the one pointing the other way; the
    % pairing is an involution, so applying it twice is the identity.
    (
        Dir = up,
        Opposite = down
    ;
        Dir = down,
        Opposite = up
    ;
        Dir = left,
        Opposite = right
    ;
        Dir = right,
        Opposite = left
    ;
        Dir = upper_left,
        Opposite = lower_right
    ;
        Dir = lower_right,
        Opposite = upper_left
    ;
        Dir = upper_right,
        Opposite = lower_left
    ;
        Dir = lower_left,
        Opposite = upper_right
    ).

%---------------------------------------------------------------------------%

adjacent(square, cell(X, Y), Dir) = cell(X + DX, Y + DY) :-
    % Choose the (DX, DY) step for Dir; the step is applied once, in the
    % clause head. On square grids Y grows upwards and X grows rightwards.
    (
        Dir = up,
        DX = 0,
        DY = 1
    ;
        Dir = down,
        DX = 0,
        DY = -1
    ;
        Dir = left,
        DX = -1,
        DY = 0
    ;
        Dir = right,
        DX = 1,
        DY = 0
    ;
        ( Dir = upper_left
        ; Dir = upper_right
        ; Dir = lower_left
        ; Dir = lower_right
        ),
        unexpected($pred, "hex direction used with square topology")
    ).

adjacent(hex, cell(Q, R), Dir) = cell(Q + DQ, R + DR) :-
    % Axial basis: one step upper_right is (+1, 0) and one step down is
    % (0, +1). The remaining directions are negations or sums of these.
    (
        Dir = up,
        DQ = 0,
        DR = -1
    ;
        Dir = down,
        DQ = 0,
        DR = 1
    ;
        Dir = upper_right,
        DQ = 1,
        DR = 0
    ;
        Dir = lower_left,
        DQ = -1,
        DR = 0
    ;
        Dir = upper_left,           % lower_left + up
        DQ = -1,
        DR = -1
    ;
        Dir = lower_right,          % upper_right + down
        DQ = 1,
        DR = 1
    ;
        ( Dir = left
        ; Dir = right
        ),
        unexpected($pred, "square direction used with hex topology")
    ).

%---------------------------------------------------------------------------%

get_edge(Topology, Cell, Dir) = Edge :-
    % Normalise so that each physical edge has exactly one representation:
    % it is stored against the cell from which it leaves in one of this
    % topology's canonical directions.
    %
    % The canonical test is made against canonical_directions(Topology),
    % not against a topology-independent set. Otherwise a direction from
    % the other topology (e.g. upper_right on a square grid) would be
    % accepted as canonical and produce an edge that edge_cells/4 cannot
    % resolve. Such a direction now falls through to adjacent/3, which
    % rejects it with an exception, as it already did for the
    % non-canonical wrong-topology directions.
    ( if list.member(Dir, canonical_directions(Topology)) then
        Edge = edge(Cell, Dir)
    else
        % Non-canonical: store it from the neighbour's side instead.
        AdjCell = adjacent(Topology, Cell, Dir),
        Edge = edge(AdjCell, opposite(Dir))
    ).

get_edge_between(Topology, FromCell, ToCell) = Edge :-
    % Find the step leading from FromCell to ToCell, then let get_edge/3
    % canonicalise. The search only tries directions(Topology), so a
    % failed search means the two cells are not neighbours.
    ( if find_direction(Topology, FromCell, ToCell, StepDir) then
        Edge = get_edge(Topology, FromCell, StepDir)
    else
        unexpected($pred, "cells are not adjacent")
    ).

:- pred is_canonical(direction::in) is semidet.

    % Succeeds for one member of each opposite pair (up/down, right/left,
    % upper_right/lower_left, lower_right/upper_left) and fails for the
    % other four directions. This test does not take the topology into
    % account.
    %
is_canonical(Dir) :-
    ( Dir = up
    ; Dir = right
    ; Dir = upper_right
    ; Dir = lower_right
    ).

:- pred find_direction(topology::in, cell::in, cell::in, direction::out)
    is semidet.

    % Succeed with the first direction in directions(Topology) that steps
    % from Cell1 to Cell2. Fail when there is none, i.e. when Cell2 is not
    % a neighbour of Cell1 (including when the two cells are equal).
    %
find_direction(Topology, Cell1, Cell2, Dir) :-
    StepsToCell2 =
        ( pred(Candidate::in) is semidet :-
            adjacent(Topology, Cell1, Candidate) = Cell2
        ),
    list.find_first_match(StepsToCell2, directions(Topology), Dir).

%---------------------------------------------------------------------------%

edge_cells(Topology, Edge, Cell, AdjCell) :-
    % The stored cell is one endpoint; stepping from it in the stored
    % direction reaches the other.
    Edge = edge(Cell, StoredDir),
    AdjCell = adjacent(Topology, Cell, StoredDir).

%---------------------------------------------------------------------------%

direction_delta(Topology, Dir, DX, DY) :-
    % Stepping from the origin yields exactly the offset, so this cannot
    % drift out of sync with adjacent/3.
    adjacent(Topology, cell(0, 0), Dir) = cell(DX, DY).

%---------------------------------------------------------------------------%
:- end_module grid.
%---------------------------------------------------------------------------%
