Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 76 additions & 30 deletions chess/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,28 @@ def piece_name(piece_type: PieceType) -> str:
"P": "♙", "p": "♟",
}

File: TypeAlias = int
FILE_A: File = 0
FILE_B: File = 1
FILE_C: File = 2
FILE_D: File = 3
FILE_E: File = 4
FILE_F: File = 5
FILE_G: File = 6
FILE_H: File = 7
FILES = [FILE_A, FILE_B, FILE_C, FILE_D, FILE_E, FILE_F, FILE_G, FILE_H]
FILE_NAMES = ["a", "b", "c", "d", "e", "f", "g", "h"]

Rank: TypeAlias = int
RANK_1: Rank = 0
RANK_2: Rank = 1
RANK_3: Rank = 2
RANK_4: Rank = 3
RANK_5: Rank = 4
RANK_6: Rank = 5
RANK_7: Rank = 6
RANK_8: Rank = 7
RANKS = [RANK_1, RANK_2, RANK_3, RANK_4, RANK_5, RANK_6, RANK_7, RANK_8]
RANK_NAMES = ["1", "2", "3", "4", "5", "6", "7", "8"]

STARTING_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
Expand Down Expand Up @@ -251,15 +271,41 @@ def square_name(square: Square) -> str:
"""Gets the name of the square, like ``a3``."""
return SQUARE_NAMES[square]

def square(file_index: int, rank_index: int) -> Square:
def square(file_index: File, rank_index: Rank) -> Square:
"""Gets a square number by file and rank index."""
return rank_index * 8 + file_index

def square_file(square: Square) -> int:
def parse_file(name: str) -> File:
"""
Gets the file index for the given file *name*
(e.g., ``a`` returns ``0``).

:raises: :exc:`ValueError` if the file name is invalid.
"""
return FILE_NAMES.index(name)

def file_name(file: File) -> str:
"""Gets the name of the file, like ``a``."""
return FILE_NAMES[file]

def parse_rank(name: str) -> File:
"""
Gets the rank index for the given rank *name*
(e.g., ``1`` returns ``0``).

:raises: :exc:`ValueError` if the rank name is invalid.
"""
return FILE_NAMES.index(name)

def rank_name(rank: Rank) -> str:
"""Gets the name of the rank, like ``1``."""
return FILE_NAMES[rank]

def square_file(square: Square) -> File:
"""Gets the file index of the square where ``0`` is the a-file."""
return square & 7

def square_rank(square: Square) -> int:
def square_rank(square: Square) -> Rank:
"""Gets the rank index of the square where ``0`` is the first rank."""
return square >> 3

Expand Down Expand Up @@ -376,24 +422,24 @@ def square_mirror(square: Square) -> Square:
BB_LIGHT_SQUARES: Bitboard = 0x55aa_55aa_55aa_55aa
BB_DARK_SQUARES: Bitboard = 0xaa55_aa55_aa55_aa55

BB_FILE_A: Bitboard = 0x0101_0101_0101_0101 << 0
BB_FILE_B: Bitboard = 0x0101_0101_0101_0101 << 1
BB_FILE_C: Bitboard = 0x0101_0101_0101_0101 << 2
BB_FILE_D: Bitboard = 0x0101_0101_0101_0101 << 3
BB_FILE_E: Bitboard = 0x0101_0101_0101_0101 << 4
BB_FILE_F: Bitboard = 0x0101_0101_0101_0101 << 5
BB_FILE_G: Bitboard = 0x0101_0101_0101_0101 << 6
BB_FILE_H: Bitboard = 0x0101_0101_0101_0101 << 7
BB_FILE_A: Bitboard = 0x0101_0101_0101_0101 << FILE_A
BB_FILE_B: Bitboard = 0x0101_0101_0101_0101 << FILE_B
BB_FILE_C: Bitboard = 0x0101_0101_0101_0101 << FILE_C
BB_FILE_D: Bitboard = 0x0101_0101_0101_0101 << FILE_D
BB_FILE_E: Bitboard = 0x0101_0101_0101_0101 << FILE_E
BB_FILE_F: Bitboard = 0x0101_0101_0101_0101 << FILE_F
BB_FILE_G: Bitboard = 0x0101_0101_0101_0101 << FILE_G
BB_FILE_H: Bitboard = 0x0101_0101_0101_0101 << FILE_H
BB_FILES: List[Bitboard] = [BB_FILE_A, BB_FILE_B, BB_FILE_C, BB_FILE_D, BB_FILE_E, BB_FILE_F, BB_FILE_G, BB_FILE_H]

BB_RANK_1: Bitboard = 0xff << (8 * 0)
BB_RANK_2: Bitboard = 0xff << (8 * 1)
BB_RANK_3: Bitboard = 0xff << (8 * 2)
BB_RANK_4: Bitboard = 0xff << (8 * 3)
BB_RANK_5: Bitboard = 0xff << (8 * 4)
BB_RANK_6: Bitboard = 0xff << (8 * 5)
BB_RANK_7: Bitboard = 0xff << (8 * 6)
BB_RANK_8: Bitboard = 0xff << (8 * 7)
BB_RANK_1: Bitboard = 0xff << (8 * RANK_1)
BB_RANK_2: Bitboard = 0xff << (8 * RANK_2)
BB_RANK_3: Bitboard = 0xff << (8 * RANK_3)
BB_RANK_4: Bitboard = 0xff << (8 * RANK_4)
BB_RANK_5: Bitboard = 0xff << (8 * RANK_5)
BB_RANK_6: Bitboard = 0xff << (8 * RANK_6)
BB_RANK_7: Bitboard = 0xff << (8 * RANK_7)
BB_RANK_8: Bitboard = 0xff << (8 * RANK_8)
BB_RANKS: List[Bitboard] = [BB_RANK_1, BB_RANK_2, BB_RANK_3, BB_RANK_4, BB_RANK_5, BB_RANK_6, BB_RANK_7, BB_RANK_8]

BB_BACKRANKS: Bitboard = BB_RANK_1 | BB_RANK_8
Expand Down Expand Up @@ -1847,7 +1893,7 @@ def generate_pseudo_legal_moves(self, from_mask: Bitboard = BB_ALL, to_mask: Bit
self.occupied_co[not self.turn] & to_mask)

for to_square in scan_reversed(targets):
if square_rank(to_square) in [0, 7]:
if square_rank(to_square) in [RANK_1, RANK_8]:
yield Move(from_square, to_square, QUEEN)
yield Move(from_square, to_square, ROOK)
yield Move(from_square, to_square, BISHOP)
Expand All @@ -1870,7 +1916,7 @@ def generate_pseudo_legal_moves(self, from_mask: Bitboard = BB_ALL, to_mask: Bit
for to_square in scan_reversed(single_moves):
from_square = to_square + (8 if self.turn == BLACK else -8)

if square_rank(to_square) in [0, 7]:
if square_rank(to_square) in [RANK_1, RANK_8]:
yield Move(from_square, to_square, QUEEN)
yield Move(from_square, to_square, ROOK)
yield Move(from_square, to_square, BISHOP)
Expand All @@ -1897,7 +1943,7 @@ def generate_pseudo_legal_ep(self, from_mask: Bitboard = BB_ALL, to_mask: Bitboa
capturers = (
self.pawns & self.occupied_co[self.turn] & from_mask &
BB_PAWN_ATTACKS[not self.turn][self.ep_square] &
BB_RANKS[4 if self.turn else 3])
BB_RANKS[RANK_5 if self.turn else RANK_4])

for capturer in scan_reversed(capturers):
yield Move(capturer, self.ep_square)
Expand Down Expand Up @@ -1977,9 +2023,9 @@ def is_pseudo_legal(self, move: Move) -> bool:
if piece != PAWN:
return False

if self.turn == WHITE and square_rank(move.to_square) != 7:
if self.turn == WHITE and square_rank(move.to_square) != RANK_8:
return False
elif self.turn == BLACK and square_rank(move.to_square) != 0:
elif self.turn == BLACK and square_rank(move.to_square) != RANK_1:
return False

# Handle castling.
Expand Down Expand Up @@ -2401,18 +2447,18 @@ def push(self, move: Move) -> None:
else:
self.castling_rights &= ~BB_RANK_8
elif captured_piece_type == KING and not self.promoted & to_bb:
if self.turn == WHITE and square_rank(move.to_square) == 7:
if self.turn == WHITE and square_rank(move.to_square) == RANK_8:
self.castling_rights &= ~BB_RANK_8
elif self.turn == BLACK and square_rank(move.to_square) == 0:
elif self.turn == BLACK and square_rank(move.to_square) == RANK_1:
self.castling_rights &= ~BB_RANK_1

# Handle special pawn moves.
if piece_type == PAWN:
diff = move.to_square - move.from_square

if diff == 16 and square_rank(move.from_square) == 1:
if diff == 16 and square_rank(move.from_square) == RANK_2:
self.ep_square = move.from_square + 8
elif diff == -16 and square_rank(move.from_square) == 6:
elif diff == -16 and square_rank(move.from_square) == RANK_7:
self.ep_square = move.from_square - 8
elif move.to_square == ep_square and abs(diff) in [7, 9] and not captured_piece_type:
# Remove pawns captured en passant.
Expand Down Expand Up @@ -3605,11 +3651,11 @@ def _valid_ep_square(self) -> Optional[Square]:
return None

if self.turn == WHITE:
ep_rank = 5
ep_rank = RANK_6
pawn_mask = shift_down(BB_SQUARES[self.ep_square])
seventh_rank_mask = shift_up(BB_SQUARES[self.ep_square])
else:
ep_rank = 2
ep_rank = RANK_3
pawn_mask = shift_up(BB_SQUARES[self.ep_square])
seventh_rank_mask = shift_down(BB_SQUARES[self.ep_square])

Expand Down
8 changes: 4 additions & 4 deletions chess/gaviota.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ def idx_is_empty(x: int) -> int:
def flip_type(x: chess.Square, y: chess.Square) -> int:
ret = 0

if chess.square_file(x) > 3:
if chess.square_file(x) > chess.FILE_D:
x = flip_we(x)
y = flip_we(y)
ret |= 1

if chess.square_rank(x) > 3:
if chess.square_rank(x) > chess.RANK_4:
x = flip_ns(x)
y = flip_ns(y)
ret |= 2
Expand Down Expand Up @@ -351,11 +351,11 @@ def init_ppidx() -> Tuple[List[List[int]], List[int], List[int]]:


def norm_kkindex(x: chess.Square, y: chess.Square) -> Tuple[int, int]:
if chess.square_file(x) > 3:
if chess.square_file(x) > chess.FILE_D:
x = flip_we(x)
y = flip_we(y)

if chess.square_rank(x) > 3:
if chess.square_rank(x) > chess.RANK_4:
x = flip_ns(x)
y = flip_ns(y)

Expand Down
2 changes: 1 addition & 1 deletion chess/syzygy.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ def calc_symlen(self, d: PairsData, s: int, tmp: List[int]) -> None:
d.symlen[s] = d.symlen[s1] + d.symlen[s2] + 1
tmp[s] = 1

def pawn_file(self, pos: List[chess.Square]) -> int:
def pawn_file(self, pos: List[chess.Square]) -> chess.File:
for i in range(1, self.pawns[0]):
if FLAP[pos[0]] > FLAP[pos[i]]:
pos[0], pos[i] = pos[i], pos[0]
Expand Down