# Peg puzzle solver: given the positions of the three pegs on a 5x5 board,
# find every way to fill the remaining cells with the puzzle pieces.

from board import Board, generate_orientations
from typing import Tuple, List
import pieces
import sys
import numpy as np
import colorama
import hashlib

def backtrack(board: Board, pieces_list: List[str], piece_index: int,
              orientations: dict, solutions: List[np.ndarray]) -> None:
    """
    Depth-first search over every placement of the remaining pieces.

    Pieces are placed in the order of ``pieces_list``; the piece at position
    ``i`` is written onto the board with id ``i + 1``. Once every piece has
    been placed, a copy of ``board.grid`` is appended to ``solutions`` if
    ``board.is_full()`` reports no free cell left. Each placement is undone
    with ``board.remove`` after its subtree is explored (assuming ``remove``
    exactly reverses ``place``).

    Args:
        board: board being filled; mutated during the search.
        pieces_list: piece names, in placement order.
        piece_index: index in ``pieces_list`` of the next piece to place.
        orientations: maps each piece name to its 2-D orientation arrays.
        solutions: output list that receives a copy of each complete grid.
    """
    # Every piece is on the board: keep the grid only if it has no gaps.
    if piece_index == len(pieces_list):
        if board.is_full():
            solutions.append(board.grid.copy())
        return

    # Piece ids are 1-based, so the id of this piece equals the next index.
    next_index = piece_index + 1
    for shape in orientations[pieces_list[piece_index]]:
        height, width = shape.shape
        # Only anchors where the shape's bounding box stays on the board.
        for top in range(board.rows - height + 1):
            for left in range(board.cols - width + 1):
                anchor = (top, left)
                if not board.placeable(shape, anchor):
                    continue
                board.place(shape, anchor, next_index)
                backtrack(board, pieces_list, next_index, orientations, solutions)
                board.remove(shape, anchor)


def solve(xPin: Tuple[int, int], yPin: Tuple[int, int],
          zPin: Tuple[int, int]) -> List[np.ndarray]:
    """
    Solve the puzzle on a 5x5 board, avoiding the pins at the given positions.

    Each pin cell is set to -1 on the board grid. Every piece in
    ``pieces.all_pieces`` is then searched in all orientations produced by
    ``generate_orientations``, using ``backtrack``.

    Args:
        xPin (Tuple[int, int]): row, col of the 'x' pin
        yPin (Tuple[int, int]): row, col of the 'y' pin
        zPin (Tuple[int, int]): row, col of the 'z' pin

    Returns:
        One grid copy per solution found, or an empty list if there is none.
        Pin cells hold -1. Covered cells presumably hold the 1-based position
        of the covering piece in ``pieces.all_pieces`` (the id ``backtrack``
        passes to ``board.place``).

    Raises:
        IndexError: if a pin lies outside the board. Negative coordinates are
            rejected explicitly because NumPy indexing would otherwise wrap
            them to the opposite edge and silently pin the wrong cell.
    """
    board = Board((5, 5))
    n_rows, n_cols = board.grid.shape

    for row, col in (xPin, yPin, zPin):
        if not (0 <= row < n_rows and 0 <= col < n_cols):
            raise IndexError(
                f"pin ({row}, {col}) is outside the {n_rows}x{n_cols} board")
        board.grid[row, col] = -1

    orientations_map = {
        name: generate_orientations(matrix)
        for name, matrix in pieces.all_pieces.items()
    }
    solutions: List[np.ndarray] = []
    backtrack(board, list(pieces.all_pieces), 0, orientations_map, solutions)
    return solutions


def colorful_solution(xPin, yPin, zPin) -> None:
    """
    Solve the puzzle for the given pins and print every solution in colour.

    Each board cell is drawn as two spaces on a coloured background:
    black for pins (-1), light white for empty cells (0), and a per-piece
    colour for piece ids 1, 2, 3, ... (cycling after six). Each grid is
    preceded by a "Solution N" header and followed by a separator line.
    The total count is printed last. If there are no solutions, only
    "No solution found." is printed.

    Args:
        xPin, yPin, zPin: (row, col) pin positions, forwarded to ``solve``.

    Returns:
        None; all output goes to stdout.
    """
    solutions = solve(xPin, yPin, zPin)
    if not solutions:
        print("No solution found.")
        return

    # Background colours for piece ids 1..6, reused cyclically beyond that.
    piece_colors = (
        colorama.Back.GREEN,
        colorama.Back.BLUE,
        colorama.Back.YELLOW,
        colorama.Back.MAGENTA,
        colorama.Back.CYAN,
        colorama.Back.RED,
    )

    def cell_color(piece_id) -> str:
        """Return the background colour code for one grid value."""
        if piece_id == -1:  # pinned cell
            return colorama.Back.BLACK
        if piece_id == 0:  # empty cell
            return colorama.Back.LIGHTWHITE_EX
        # Wrap around so boards with more than six pieces still render
        # instead of raising KeyError.
        return piece_colors[(int(piece_id) - 1) % len(piece_colors)]

    for number, solution in enumerate(solutions, start=1):
        # Reset after every cell so the colour never bleeds into the next
        # cell or past the end of the line.
        result = "".join(
            "".join(cell_color(value) + "  " + colorama.Back.RESET
                    for value in grid_row) + "\n"
            for grid_row in solution
        )
        print(f"Solution {number}")
        print(result)
        print("--------------------")
    print(f"Total solutions: {len(solutions)}")

def check_rotations(Ax, Ay, Bx, By, Cx, Cy, checked, size=5):
    """
    Report whether this pin set, up to rotation, has been seen before.

    The three pins are treated as an unordered set: ``solve`` marks every pin
    cell identically, so relabelling which pin is x, y or z gives the same
    puzzle. Each of the four 90-degree rotations of the set is turned into a
    canonical key (the pin coordinates sorted into a tuple).

    If any key is already in ``checked``, return True. Otherwise add all four
    keys to ``checked`` and return False, so later equivalent sets are
    recognised.

    Args:
        Ax, Ay, Bx, By, Cx, Cy: row/col of the three pins, each expected in
            ``[0, size - 1]``.
        checked: set of canonical keys seen so far; mutated in place.
        size: side length of the square board (default 5).

    Returns:
        True if an equivalent pin set was already recorded, else False.
    """
    last = size - 1

    def rotate_90(cell):
        # Quarter turn on a size x size board: (r, c) -> (c, size - 1 - r).
        r, c = cell
        return (c, last - r)

    pins = [(Ax, Ay), (Bx, By), (Cx, Cy)]
    keys = []
    # Four quarter turns; the fourth returns the pins to their original
    # positions, so the unrotated set is included.
    for _ in range(4):
        pins = [rotate_90(p) for p in pins]
        # Sort so the key does not depend on the order the pins were given in.
        keys.append(tuple(sorted(pins)))

    if any(key in checked for key in keys):
        return True

    checked.update(keys)
    return False


def try_all_pegs_naive():
    """
    Solve the puzzle for every essentially different placement of three pins.

    Each unordered set of three distinct cells on the 5x5 board is visited
    once. Pins are interchangeable because ``solve`` marks them all the same
    way. Sets that are a rotation of one already visited are skipped via
    ``check_rotations``. For every remaining set with at least one solution,
    a summary dict is printed as soon as it is found.

    Returns:
        List of summary dicts ``{'x': (row, col), 'y': (row, col),
        'z': (row, col), 'solution_count': int}``, in the order printed.
    """
    from itertools import combinations, product

    from tqdm import tqdm

    cells = list(product(range(5), repeat=2))  # (row, col) in row-major order
    # C(25, 3) = 2300 pin sets before rotation pruning.
    pin_sets = list(combinations(cells, 3))

    all_solutions = []
    checked = set()
    for pin_a, pin_b, pin_c in tqdm(pin_sets):
        if check_rotations(*pin_a, *pin_b, *pin_c, checked):
            continue

        solutions = solve(pin_a, pin_b, pin_c)
        if solutions:
            entry = {
                'x': pin_a,
                'y': pin_b,
                'z': pin_c,
                'solution_count': len(solutions),
            }
            all_solutions.append(entry)
            # Flush so results show up while the long sweep is still running.
            print(entry, flush=True)
    return all_solutions

    
if __name__ == "__main__":
    # Demo: print every solution with pins in three corners of the board.
    # To sweep all pin placements instead, call try_all_pegs_naive().
    demo_pins = ((0, 0), (4, 4), (4, 0))
    colorful_solution(*demo_pins)