Skip to content

feature: CachedEngine, useful for tests #1157

@kraktus

Description

@kraktus

Not sure whether this is within the library's scope, but almost every time I write tests I want to cache the engine's evaluations, to avoid flakiness.

Here is my current (wonky, file-based) implementation, which I use with the async API:

import asyncio
import inspect
import zlib

import chess

from typing import List, Union, Dict, Any, Callable, Tuple
from pathlib import Path

from chess import WHITE, BLACK, Move
from chess.engine import Score, Cp, Mate, PovScore, InfoDict

# Engine command passed to popen_uci (presumably resolved via PATH — confirm).
STOCKFISH = "stockfish"

# Signature of chess.engine.Protocol.analyse; get_checksum_args binds call
# arguments against it to get a parameter-name -> value mapping.
ANALYSE_SIGN = inspect.signature(chess.engine.Protocol.analyse)


def get_checksum_args(*args, **kwargs) -> bytes:
    """
    Serialize ``analyse`` call arguments into the bytes used as a cache key.

    The arguments are bound against the signature of
    ``chess.engine.Protocol.analyse``. As a result, positional and keyword
    spellings of the same call produce the same key. ``self`` is dropped.
    This function does NOT hash anything: the caller turns the returned bytes
    into a checksum (``CachedEngine.analyse`` uses ``zlib.adler32``).

    Raises:
        TypeError: if the arguments do not match the ``analyse`` signature.
    """
    # Defaults are deliberately not applied (no ``apply_defaults``). Applying
    # them would change every key and orphan existing diskettes. Consequence:
    # passing a parameter explicitly with its default value yields a different
    # key than omitting it.
    original_dict = ANALYSE_SIGN.bind(*args, **kwargs).arguments
    # remove self from the arguments, because it contains the pid changing every
    # time (deepcopy is not possible, due to asyncio shenanigans)
    checksum_dict = {k: v for k, v in original_dict.items() if k != "self"}
    return str(checksum_dict).encode("utf-8")


class CachedEngine(chess.engine.UciProtocol):
    """UCI protocol whose ``analyse`` results are cached on disk.

    Each distinct set of ``analyse`` arguments (``self`` excluded) is
    serialized by ``get_checksum_args`` and keyed by its Adler-32 checksum.
    The entry is stored as ``diskettes/<checksum>.py``. Each file holds a
    ``#<key>`` header line followed by ``str()`` of the analysis result, which
    is read back with ``eval``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Checksums looked up during this session; see list_unused_evals.
        self.__used_checksums = set()
        # named after cassettes in VCR
        self.__diskette_dir = Path("diskettes")
        self.__diskette_dir.mkdir(exist_ok=True)

    async def analyse(self, *args, **kwargs) -> Union[List[InfoDict], InfoDict]:
        """Return the cached analysis for these arguments, or run and cache it.

        Arguments are forwarded unchanged to ``UciProtocol.analyse``. The key
        is built from the call arguments only.
        NOTE(review): engine options configured on the protocol do not appear
        to be part of the key — confirm this is acceptable for your tests.
        """
        checksum_arg = get_checksum_args(self, *args, **kwargs)
        checksum = zlib.adler32(checksum_arg)
        self.__used_checksums.add(checksum)
        path = self.__diskette_dir / f"{checksum}.py"
        header = f"#{checksum_arg}"
        if path.exists():
            header_line, _, body = path.read_text(encoding="utf-8").partition("\n")
            # Adler-32 is a checksum, not a hash: different keys can collide.
            # Only reuse the entry if its stored key matches exactly. On a
            # mismatch, fall through and overwrite it with a fresh analysis.
            if header_line == header:
                # SECURITY: eval executes the file contents; only use diskettes
                # you trust. Names in the stored repr (Cp, Mate, PovScore, Move,
                # WHITE, BLACK, ...) must exist in this module's namespace.
                return eval(body)
        res = await super().analyse(*args, **kwargs)
        # Write to a temporary file, then move it into place. This way an
        # interrupted run cannot leave a truncated diskette that fails to eval.
        tmp = path.with_suffix(".tmp")
        tmp.write_text(f"{header}\n{res}", encoding="utf-8")
        tmp.replace(path)
        return res

    def list_unused_evals(self) -> List[int]:
        """Return checksums of diskettes not looked up during this session.

        Only ``<digits>.py`` files are considered. Stray files in the directory
        (temporary files, editor backups, ``__pycache__``) are skipped instead
        of making ``int()`` raise ``ValueError``.
        """
        return [
            int(x.stem)
            for x in self.__diskette_dir.glob("*.py")
            if x.stem.isdecimal() and int(x.stem) not in self.__used_checksums
        ]


# pasted from python-chess source code
async def popen_uci(
    command: Union[str, List[str]], *, setpgrp: bool = False, **popen_args: Any
) -> Tuple[asyncio.SubprocessTransport, CachedEngine]:
    """Spawn *command* as a UCI engine driven by ``CachedEngine``.

    Returns the subprocess transport and the initialized protocol. If
    ``initialize()`` fails or is cancelled, the transport is closed and the
    exception is re-raised.
    """
    transport, protocol = await CachedEngine.popen(
        command, setpgrp=setpgrp, **popen_args
    )
    try:
        await protocol.initialize()
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError and
        # KeyboardInterrupt must also close the subprocess before propagating.
        transport.close()
        raise
    return transport, protocol


async def test():
    """Smoke test: analyse the starting position for 5 seconds and print it."""
    _transport, engine = await popen_uci(STOCKFISH)
    try:
        info = await engine.analyse(chess.Board(), chess.engine.Limit(time=5))
        print(info)
    finally:
        # Always stop the engine subprocess, even if analyse raises.
        await engine.quit()


if __name__ == "__main__":
    # Separator line before this run's output.
    print("#" * 80)
    asyncio.run(test())

And here is a version for `SimpleEngine`, as found in lichess-puzzler (less general than the one above):

class CachedEngine(SimpleEngine):
    """SimpleEngine whose ``analyse`` results are cached on disk.

    The key is ``"<fen> <multipv> <limit>"``, identified by its Adler-32
    checksum and stored as ``diskettes/<checksum>.py``. Each file holds a
    ``#<key>`` header line followed by ``str()`` of the result, read back
    with ``eval``. Note that a FEN does not encode move history (e.g.
    repetitions), so positions reached via different histories share an entry.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Checksums looked up during this session; see list_unused_evals.
        self.used_checksums = set()
        # named after cassettes in VCR
        self.diskette_dir = Path("diskettes")
        self.diskette_dir.mkdir(exist_ok=True)

    # a more general implementation should use the `inspect` module and `Signature.bind`
    def analyse(self, board: Board, multipv: int, limit: chess.engine.Limit) -> Union[List[InfoDict], InfoDict]:
        """Return the cached analysis for these arguments, or run and cache it."""
        checksum_arg = f"{board.fen()} {multipv} {limit}".encode()
        checksum = zlib.adler32(checksum_arg)
        self.used_checksums.add(checksum)
        path = self.diskette_dir / f"{checksum}.py"
        header = f"#{checksum_arg}"
        if path.exists():
            header_line, _, body = path.read_text(encoding="utf-8").partition("\n")
            # Adler-32 can collide: only reuse the entry if its stored key
            # matches. On a mismatch, re-analyse and overwrite it.
            if header_line == header:
                # SECURITY: eval executes the file; only use trusted diskettes.
                return eval(body)
        res = super().analyse(board=board, multipv=multipv, limit=limit)
        # Temp file + replace, so an interrupted run leaves no truncated entry.
        tmp = path.with_suffix(".tmp")
        tmp.write_text(f"{header}\n{res}", encoding="utf-8")
        tmp.replace(path)
        return res

    def list_unused_evals(self) -> List[int]:
        """Return checksums of diskettes not looked up during this session.

        Non-numeric files in the directory are skipped instead of making
        ``int()`` raise ``ValueError``.
        """
        return [
            int(x.stem)
            for x in self.diskette_dir.glob("*.py")
            if x.stem.isdecimal() and int(x.stem) not in self.used_checksums
        ]

I could see a more general abstract `CachedEngine` class in which only `get_cache(checksum: int) -> info | None` and `set_cache(checksum: int, info) -> None` need to be implemented, with a default file-based implementation also provided.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions