140 changes: 140 additions & 0 deletions .gitignore
@@ -0,0 +1,140 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
.pytest_cache/
.coverage
htmlcov/
coverage.xml
.tox/
.coverage.*
.cache
nosetests.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# TensorFlow
*.pb
*.pbtxt
*.ckpt.*
tensorboard_logs/
saved_models/

# PyTorch
*.pth
*.pt

# IDE files
.vscode/
.idea/
*.swp
*.swo
*~

# OS files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Claude Code settings
.claude/
2,434 changes: 2,434 additions & 0 deletions poetry.lock

Large diffs are not rendered by default.

87 changes: 87 additions & 0 deletions pyproject.toml
@@ -0,0 +1,87 @@
[tool.poetry]
name = "background-matting-tensorflow"
version = "0.1.0"
description = "TensorFlow implementation of Real-Time High-Resolution Background Matting"
authors = ["Background Matting Team"]
readme = "README.md"
packages = [{include = "model"}]

[tool.poetry.dependencies]
python = "^3.8"
tensorflow-cpu = "^2.8.0"
# Sources with priority = "explicit" are only consulted when referenced per-dependency.
torch = {version = "^1.10.0", source = "pytorch-cpu"}

[[tool.poetry.source]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
priority = "explicit"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
pytest-cov = "^4.1.0"
pytest-mock = "^3.11.0"

[tool.poetry.scripts]
test = "pytest:main"
tests = "pytest:main"
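# Illustrative usage: after `poetry install`, `poetry run test` (or `poetry run tests`)
# should invoke pytest.main(), picking up the addopts and testpaths configured
# under [tool.pytest.ini_options] below.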

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py", "*_test.py"]
python_classes = ["Test*"]
python_functions = ["test_*"]
addopts = [
"--strict-markers",
"--strict-config",
"--verbose",
"-ra",
"--cov=model",
"--cov-report=term-missing",
"--cov-report=html:htmlcov",
"--cov-report=xml:coverage.xml",
"--cov-fail-under=10"
]
markers = [
"unit: Unit tests for individual components",
"integration: Integration tests for component interactions",
"slow: Slow running tests that may take several seconds"
]
filterwarnings = [
"ignore::DeprecationWarning",
"ignore::PendingDeprecationWarning"
]

[tool.coverage.run]
source = ["model"]
branch = true
omit = [
"tests/*",
"*/test_*",
"*/__pycache__/*",
"*/migrations/*",
"*/venv/*",
"*/.venv/*"
]

[tool.coverage.report]
exclude_lines = [
"pragma: no cover",
"def __repr__",
"if self.debug:",
"if settings.DEBUG",
"raise AssertionError",
"raise NotImplementedError",
"if 0:",
"if __name__ == .__main__.:",
"class .*\\bProtocol\\):",
"@(abc\\.)?abstractmethod"
]
precision = 2
show_missing = true

[tool.coverage.html]
directory = "htmlcov"
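Since --strict-markers is enabled, only the unit, integration, and slow markers registered above may be used. A minimal sketch of a test module under tests/ that exercises them (the module name and assertions are hypothetical, for illustration only):

# tests/test_markers_example.py  (hypothetical module name)
import pytest


@pytest.mark.unit
def test_backbone_scale_is_a_fraction():
    # Registered marker; select with `poetry run pytest -m unit`.
    assert 0.0 < 0.25 <= 1.0


@pytest.mark.integration
@pytest.mark.slow
def test_full_pipeline_placeholder():
    # Deselect while iterating locally with `-m "not slow"`.
    assert True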
1 change: 1 addition & 0 deletions tests/__init__.py
@@ -0,0 +1 @@
# Testing package initialization
163 changes: 163 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,163 @@
"""Shared pytest fixtures for the testing suite."""

import pytest
import tensorflow as tf
import numpy as np
import tempfile
import shutil
import os
from pathlib import Path
from unittest.mock import Mock, MagicMock


@pytest.fixture
def temp_dir():
"""Create a temporary directory for test files."""
temp_path = tempfile.mkdtemp()
yield Path(temp_path)
shutil.rmtree(temp_path)


@pytest.fixture
def sample_image_shape():
"""Standard image shape for testing."""
return (1, 256, 256, 3)


@pytest.fixture
def sample_tensor(sample_image_shape):
"""Create a sample TensorFlow tensor for testing."""
return tf.random.normal(sample_image_shape, dtype=tf.float32)


@pytest.fixture
def sample_image_pair(sample_image_shape):
"""Create a pair of sample images (source and background) for testing."""
src = tf.random.normal(sample_image_shape, dtype=tf.float32)
bgr = tf.random.normal(sample_image_shape, dtype=tf.float32)
return src, bgr


@pytest.fixture
def mock_tensorflow_model():
"""Mock TensorFlow model for testing."""
mock_model = Mock(spec=tf.keras.Model)
mock_model.predict.return_value = [
np.random.random((1, 256, 256, 1)), # pha
np.random.random((1, 256, 256, 3)), # fgr
np.random.random((1, 256, 256, 1)), # err
np.random.random((1, 256, 256, 32)) # hid
]
return mock_model


@pytest.fixture
def mock_torch_weights():
"""Mock PyTorch weights dictionary for testing."""
return {
'backbone.conv1.weight': np.random.random((64, 6, 7, 7)),
'backbone.bn1.weight': np.random.random((64,)),
'backbone.bn1.bias': np.random.random((64,)),
'aspp.convs.0.weight': np.random.random((256, 2048, 1, 1)),
'decoder.conv.weight': np.random.random((48, 256, 3, 3))
}


@pytest.fixture
def config_dict():
"""Sample configuration dictionary for testing."""
return {
'backbone': 'resnet50',
'backbone_scale': 0.25,
'refine_mode': 'sampling',
'refine_sample_pixels': 80000,
'refine_threshold': 0.7
}


@pytest.fixture(autouse=True)
def reset_tensorflow():
"""Reset TensorFlow state between tests."""
tf.keras.backend.clear_session()
yield
tf.keras.backend.clear_session()


@pytest.fixture
def disable_mixed_precision():
"""Disable mixed precision for consistent testing."""
original_policy = tf.keras.mixed_precision.global_policy()
tf.keras.mixed_precision.set_global_policy('float32')
yield
tf.keras.mixed_precision.set_global_policy(original_policy)


@pytest.fixture
def mock_file_system(temp_dir):
"""Create mock file system structure for testing."""
weights_dir = temp_dir / "weights"
weights_dir.mkdir()

# Create mock weight files
(weights_dir / "model.pth").touch()
(weights_dir / "checkpoint.pth").touch()

return {
'root': temp_dir,
'weights_dir': weights_dir,
'model_path': weights_dir / "model.pth",
'checkpoint_path': weights_dir / "checkpoint.pth"
}


@pytest.fixture
def numpy_random_seed():
"""Set consistent numpy random seed for reproducible tests."""
original_state = np.random.get_state()
np.random.seed(42)
yield
np.random.set_state(original_state)


@pytest.fixture
def tensorflow_random_seed():
"""Set consistent TensorFlow random seed for reproducible tests."""
tf.random.set_seed(42)
yield


@pytest.fixture
def mock_training_mode():
"""Mock training mode context for testing."""
class TrainingModeContext:
def __init__(self):
self.training = False

def set_training(self, mode: bool):
self.training = mode
return self

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass

return TrainingModeContext()


@pytest.fixture
def large_image_shape():
"""Larger image shape for integration testing."""
return (1, 1080, 1920, 3)


@pytest.fixture
def batch_image_shapes():
"""Various batch sizes and image shapes for testing."""
return [
(1, 256, 256, 3),
(2, 256, 256, 3),
(4, 512, 512, 3),
(1, 1080, 1920, 3)
]
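A minimal sketch of how the fixtures above might compose in a test (the module name and assertions are hypothetical; shapes follow the fixture definitions in this conftest.py):

# tests/test_fixture_usage_example.py  (hypothetical module name)
import pytest


@pytest.mark.unit
def test_sample_pair_matches_shape(sample_image_pair, sample_image_shape, config_dict):
    src, bgr = sample_image_pair
    assert tuple(src.shape) == sample_image_shape
    assert tuple(bgr.shape) == sample_image_shape
    assert config_dict["backbone"] == "resnet50"


@pytest.mark.unit
def test_mocked_model_outputs(mock_tensorflow_model, sample_tensor):
    # The mock returns four arrays (pha, fgr, err, hid) regardless of input.
    pha, fgr, err, hid = mock_tensorflow_model.predict(sample_tensor)
    assert pha.shape[-1] == 1 and fgr.shape[-1] == 3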