Python + FastAPI + SQLAlchemy + DependencyInjectorで実装するWebAPI

  • OS: Xubuntu
  • Python: 3.10.x
  • Shell: zsh

プロジェクト作成

環境用意

# Create the project skeleton.
# NOTE(review): poetry derives the package directory python_fastapi_sqlalchemy
# from this name, while pyproject.toml and the code below use
# fastapi_sqlalchemy_v1 — confirm which name is intended.
poetry new python-fastapi-sqlalchemy
cd python-fastapi-sqlalchemy
# Local virtualenv, with poetry installed inside it.
python -m venv venv
. venv/bin/activate
pip install poetry
# Re-activate the venv; presumably so the shell picks up the poetry executable
# that was just installed into it.
deactivate
. venv/bin/activate
# Install the dependencies declared in pyproject.toml.
poetry install

コード

pyproject.toml

# Package metadata; `packages` tells poetry that the importable package is the
# fastapi_sqlalchemy_v1 directory.
[tool.poetry]
name = "fastapi-sqlalchemy-v1"
version = "0.1.0"
description = ""
authors = ["pecolynx <pecolynx@gmail.com>"]
readme = "README.md"
packages = [{include = "fastapi_sqlalchemy_v1"}]


# Runtime dependencies. pydantic is not listed: it is installed transitively by
# fastapi, and the code uses its 1.x API (parse_obj, Field(example=...)).
[tool.poetry.dependencies]
python = "^3.10"
fastapi = "^0.95.0"
sqlalchemy = {extras = ["mypy"], version = "^2.0.7"}
uvicorn = "^0.21.1"
pyaml-env = "^1.2.1"


# Formatters, linters and the type checker used by the Makefile targets.
[tool.poetry.group.dev.dependencies]
isort = "^5.12.0"
ruff = "^0.0.257"
black = "^23.1.0"
flake8 = "^6.0.0"
mypy = "^1.1.1"


# Test tooling; httpx is imported by the tests and is also used by FastAPI's
# TestClient.
[tool.poetry.group.test.dependencies]
pytest = "^7.2.2"
httpx = "^0.23.3"
pytest-cov = "^4.0.0"
pytest-mock = "^3.10.0"


# isort uses black's profile so the two tools agree on import formatting;
# generated protobuf modules are skipped.
[tool.isort]
profile = "black"
skip_glob = ["*_pb2.py", "*_pb2_grpc.py"]


# black: 88 columns (setup.cfg gives flake8 the same limit). The verbose regex
# below excludes tool/cache directories, the venv, alembic and protobuf output.
[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | _build
    | buck-out
    | build
    | dist
    | venv
    | alembic
  )/
  | .*_pb2.py  # exclude autogenerated Protocol Buffer files anywhere in the project
  | .*_pb2_grpc.py  # exclude autogenerated Protocol Buffer files anywhere in the project
)
'''

# Paths ruff never lints (mostly ruff's default list plus venv and alembic).
[tool.ruff]
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "venv",
    "alembic",
]

# mypy with the SQLAlchemy plugin (shipped by the sqlalchemy[mypy] extra);
# the tests directory is not type-checked.
# NOTE(review): SQLAlchemy 2.0 documents this plugin as deprecated in favour of
# Mapped[] annotations — consider migrating.
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
exclude = [
    'tests',
]

# Build backend used when poetry builds or installs the package.
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

setup.cfg

# flake8 settings, aligned with black (same 88-column limit).
[flake8]
max-line-length = 88
# E203 (whitespace before ':') contradicts how black formats complex slices,
# so it is ignored as black's documentation recommends.
extend-ignore = E203
exclude =
    .git,
    __pycache__,
    venv,
    alembic,

Makefile

# Every tool runs inside the poetry-managed environment.
CMD:=poetry run
APP_NAME:=fastapi_sqlalchemy_v1
# NOTE(review): IMAGE_NAME and DB_NAME are not referenced by any target below.
IMAGE_NAME:=fastapi-sqlalchemy-v1
DB_NAME:=dev.db.sqlite3
SHELL=/bin/bash

# Format, lint and type-check in one go.
.PHONY: all
all:
	$(CMD) isort .
	$(CMD) black .
	$(CMD) flake8 .
	$(CMD) ruff check .
	$(CMD) mypy .

# Rewrite sources in place (imports first, then black).
.PHONY: format
format:
	$(CMD) isort .
	$(CMD) black .

# Static type check.
.PHONY: type
type:
	$(CMD) mypy .

# Lint only; no files are modified.
.PHONY: check
check:
	$(CMD) flake8 .
	$(CMD) ruff check .

# Run the test suite with branch coverage against the test configuration.
.PHONY: test
test:
	CONFIG_FILE_PATH=tests/test_config.yml \
	$(CMD) pytest --cov=$(APP_NAME) --cov-branch --cov-report=term --cov-report=xml -v -s

# Start the app through main.py, which calls uvicorn.run itself.
.PHONY: run
run:
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) python $(APP_NAME)/main.py

# Start the app with the uvicorn CLI, JSON logs and auto-reload.
.PHONY: uvicorn
uvicorn:
	LOG_FORMAT=JSON \
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) uvicorn $(APP_NAME).main:app --reload

fastapi_sqlalchemy_v1/book.py

from datetime import datetime


class Book:
    """Domain model of a book.

    The four fields are set once in the constructor and exposed only through
    read-only properties (no setters are defined).
    """

    def __init__(self, book_id: int, name: str, lang2: str, created_at: datetime):
        self.__book_id, self.__name, self.__lang2, self.__created_at = (
            book_id,
            name,
            lang2,
            created_at,
        )

    @property
    def book_id(self) -> int:
        """Identifier of the book."""
        return self.__book_id

    @property
    def name(self) -> str:
        """Title of the book."""
        return self.__name

    @property
    def lang2(self) -> str:
        """Two-character language value, e.g. "ja"."""
        return self.__lang2

    @property
    def created_at(self) -> datetime:
        """Time the book was created."""
        return self.__created_at

fastapi_sqlalchemy_v1/book_repository.py

from typing import Optional

from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Integer, Text
from sqlalchemy.orm import Session
from sqlalchemy.sql import func

from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.database import Base


class BookNotFoundError(Exception):
    """Raised when no book matches a lookup.

    Attributes:
        key: name of the field used for the lookup (e.g. "id").
        value: the looked-up value, as a string.
    """

    def __init__(self, key: str, value: str):
        # Pass a message to Exception so str(err) / err.args are informative
        # when the error is logged.
        super().__init__(f"book not found: {key}={value}")
        self.key = key
        self.value = value


class BookOtherError(Exception):
    """Raised when a book operation fails for a reason other than "not found"."""


class BookAddParameter(BaseModel):
    """Validated input for BookRepository.add_book."""

    name: str = Field(..., min_length=1, max_length=20)  # 1-20 characters
    lang2: str = Field(..., min_length=2, max_length=2)  # exactly 2 chars, e.g. "ja"


class BookDbEntity(Base):
    """ORM mapping for the ``book`` table."""

    __tablename__ = "book"

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)
    lang2 = Column(Text, nullable=False)
    # server_default=func.now(): the database supplies the insert time.
    created_at = Column(DateTime, nullable=False, server_default=func.now())

    def to_model(self) -> Optional[Book]:
        """Convert this row into a domain ``Book``.

        Returns:
            The Book, or None if any column is still unset (e.g. on an
            instance that has not been flushed/loaded yet).

        Explicit ``is not None`` checks are used so that legitimate falsy
        values (an id of 0, an empty string) are not mistaken for missing data.
        """
        if (
            self.id is not None
            and self.name is not None
            and self.lang2 is not None
            and self.created_at is not None
        ):
            return Book(
                book_id=self.id,
                name=self.name,
                lang2=self.lang2,
                created_at=self.created_at,
            )
        return None


class BookRepository:
    """Data-access object that maps ``book`` rows to ``Book`` models.

    The repository only flushes; it never commits. Transaction boundaries
    belong to whoever owns the session passed in.
    """

    def __init__(self, session: Session):
        self.__session = session

    def find_book_by_id(self, book_id: int) -> Book:
        """Return the book whose primary key is ``book_id``.

        Raises:
            BookNotFoundError: no row has that id.
            BookOtherError: the row exists but could not be converted to a Book.
        """
        book_entity: BookDbEntity | None = (
            self.__session.query(BookDbEntity)
            .filter(BookDbEntity.id == book_id)
            .first()
        )
        if book_entity is None:
            raise BookNotFoundError("id", str(book_id))
        book = book_entity.to_model()
        if book is None:
            raise BookOtherError()
        return book

    def add_book(self, book_add_param: BookAddParameter) -> int:
        """Insert a new book row and return its generated id.

        The row is flushed (not committed) so the database assigns the id, and
        then refreshed, which reloads database-generated values such as
        ``created_at``.

        Raises:
            BookOtherError: no id was assigned after the flush.
        """
        book_entity = BookDbEntity(name=book_add_param.name, lang2=book_add_param.lang2)
        self.__session.add(book_entity)
        self.__session.flush()
        self.__session.refresh(book_entity)
        # Compare against None: a falsy but valid id such as 0 is not an error.
        if book_entity.id is None:
            raise BookOtherError()
        return book_entity.id

fastapi_sqlalchemy_v1/database.py

import logging
from typing import Generator

from sqlalchemy import create_engine, orm
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, scoped_session

logger = logging.getLogger(__name__)

# Declarative base shared by every ORM entity; Database.create_database()
# creates the tables registered on its metadata.
Base = declarative_base()


class Database:
    """Class-level holder for the SQLAlchemy engine and session registry.

    ``init`` must run once at startup before the other static methods are used.
    """

    # Both attributes are assigned by init().
    _engine: Engine
    _session_factory: scoped_session

    @staticmethod
    def init(db_url: str, echo: bool) -> None:
        """Create the engine and a scoped session registry bound to it.

        Args:
            db_url: SQLAlchemy database URL.
            echo: forwarded to ``create_engine(echo=...)`` to log emitted SQL.
        """
        engine = create_engine(db_url, echo=echo, pool_pre_ping=True)
        Database._engine = engine
        session_maker = orm.sessionmaker(
            bind=engine,
            autocommit=False,
            autoflush=False,
        )
        Database._session_factory = orm.scoped_session(session_maker)

    @staticmethod
    def create_database() -> None:
        """Create the tables registered on ``Base.metadata`` that do not exist yet."""
        Base.metadata.create_all(Database._engine)  # type: ignore

    @staticmethod
    def get_session() -> Generator[Session, None, None]:
        """Yield one session for a unit of work (used as a FastAPI dependency).

        When the consumer finishes normally the session is committed. If an
        exception is raised (including one thrown into the generator), it is
        logged, the session is rolled back, and the exception is re-raised.
        The session is always closed at the end.
        """
        session: Session = Database._session_factory()
        try:
            yield session
            logger.info("commit")
            session.commit()
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            logger.info("close")
            session.close()

fastapi_sqlalchemy_v1/log.py

import json
import logging
from logging import LogRecord

from fastapi_sqlalchemy_v1.config import Logger


def get_app_log(record: LogRecord) -> dict:
    """Build the JSON-serializable payload for an application log record.

    ``record`` must already have been passed through ``logging.Formatter.format``.
    That call populates ``asctime`` and ``message``, and ``exc_text`` when the
    record carries exception info.

    Returns:
        A flat dict; ``exception`` holds the formatted traceback for
        ``logger.exception(...)`` / ``exc_info=True`` records and is None
        otherwise.
    """
    return {
        "name": record.name,
        "level": record.levelname,
        "type": "app",
        "timestamp": record.asctime,
        "pathname": record.pathname,
        "line": record.lineno,
        "threadId": record.thread,
        "message": record.message,
        "stack": record.stack_info,
        # Without this the traceback of logged exceptions was silently dropped
        # from JSON output.
        "exception": record.exc_text,
    }


def get_access_log(record):
    """Build the JSON-serializable payload for a uvicorn access-log record.

    The first five positional ``record.args`` are mapped, in order, to the
    client address, method, path, HTTP version and status code. ``asctime``
    must already be set on ``record``; the formatter's base ``format`` sets it.
    """
    positional = record.args
    payload = {
        "name": record.name,
        "level": record.levelname,
        "type": "access",
        "timestamp": record.asctime,
        "threadId": record.thread,
    }
    for index, key in enumerate(
        ("clientAddr", "method", "path", "httpVersion", "statusCode")
    ):
        payload[key] = positional[index]
    return payload


class AppLogJsonFormatter(logging.Formatter):
    """Formatter that renders an application record as a single JSON object."""

    def format(self, record):
        # The base call is made only for its side effects on ``record``
        # (asctime, message, exc_text); its returned text is discarded.
        super().format(record)
        payload = get_app_log(record)
        return json.dumps(payload)


class AccessLogJsonFormatter(logging.Formatter):
    """Formatter that renders a uvicorn access record as a single JSON object."""

    def format(self, record):
        # Base formatting fills in record.asctime (and message); the string it
        # returns is not used.
        super().format(record)
        payload = get_access_log(record)
        return json.dumps(payload)


def _attach_handler(name: str, handler: logging.Handler, level: str) -> None:
    """Make logger ``name`` emit only through ``handler`` at ``level``."""
    target = logging.getLogger(name)
    target.handlers.clear()
    target.addHandler(handler)
    target.propagate = False  # avoid a second copy via the root handler
    target.setLevel(level)


def init_json_logger(level: str, loggers: list[Logger]):
    """Configure JSON logging for the app and the libraries it runs on.

    The root logger, SQLAlchemy's engine logger and the uvicorn/gunicorn
    loggers write through a stream handler using AppLogJsonFormatter. The
    ``uvicorn.access`` logger uses AccessLogJsonFormatter.

    Args:
        level: level name applied to the root logger and used as the default
            for each library logger.
        loggers: per-logger level overrides; an entry whose ``name`` matches
            one of the library loggers replaces ``level`` for that logger.
    """
    app_log_handler = logging.StreamHandler()
    app_log_handler.setFormatter(AppLogJsonFormatter("%(asctime)s"))

    access_log_handler = logging.StreamHandler()
    access_log_handler.setFormatter(AccessLogJsonFormatter("%(asctime)s"))

    logging.basicConfig(handlers=[app_log_handler], level=level)

    overrides = {log.name: log.level for log in loggers}

    # Every library logger now honours the configured overrides; previously
    # "uvicorn" and "gunicorn" silently ignored them.
    for name, handler in (
        ("sqlalchemy.engine.Engine", app_log_handler),
        ("uvicorn", app_log_handler),
        ("uvicorn.access", access_log_handler),
        ("gunicorn", app_log_handler),
    ):
        _attach_handler(name, handler, overrides.get(name, level))


def init_plaintext_logger(level: str, loggers: list[Logger]):
    """Configure plain-text logging through ``logging.basicConfig``.

    The SQLAlchemy engine, ``uvicorn`` and ``uvicorn.access`` loggers are set
    to not propagate to the root logger and get their own level.
    NOTE(review): this presumably relies on those libraries installing their
    own handlers. A non-propagating logger without handlers falls back to
    logging.lastResort, which only shows WARNING and above. Confirm.

    Args:
        level: level name for the root logger and default for the loggers above.
        loggers: per-logger level overrides, matched by ``name``.
    """
    logging.basicConfig(level=level)

    overrides = {log.name: log.level for log in loggers}

    for name in ("sqlalchemy.engine.Engine", "uvicorn", "uvicorn.access"):
        target = logging.getLogger(name)
        target.propagate = False
        target.setLevel(overrides.get(name, level))

fastapi_sqlalchemy_v1/config.py

import re
from typing import Any, Callable

from pyaml_env import parse_config  # type: ignore
from pydantic import BaseModel


class Db(BaseModel):
    """Database settings (``db`` section of the config file)."""

    url: str  # SQLAlchemy URL, passed to Database.init(db_url=...) in main.py
    show_sql: bool  # passed to Database.init as echo


class Cors(BaseModel):
    """CORS settings (``cors`` section)."""

    allow_origins: list[str]  # passed to CORSMiddleware(allow_origins=...)


class Logger(BaseModel):
    """Level override for a single named stdlib logger."""

    name: str  # logger name, e.g. "uvicorn.access"
    level: str  # level name such as "INFO"


class Log(BaseModel):
    """Logging settings (``log`` section)."""

    level: str  # global level name
    format: str  # "JSON" selects JSON logging in main.py; anything else is plain text
    loggers: list[Logger]  # per-logger level overrides


class Swagger(BaseModel):
    """API documentation settings (``swagger`` section)."""

    enabled: bool  # when False, main.py creates FastAPI with the docs UIs disabled


class App(BaseModel):
    """Server settings (``app`` section)."""

    http_port: int  # port used by main() when it calls uvicorn.run
    scheme: str  # scheme of the URL put in the Location header by POST /book
    host: str  # host[:port] of that URL


class Config(BaseModel):
    """Root of the application configuration, built by AppConfig.init."""

    db: Db
    cors: Cors
    log: Log
    swagger: Swagger
    app: App


def camel_to_snake(s: str) -> str:
    """Convert a camelCase / PascalCase identifier to snake_case.

    An underscore is inserted before an uppercase letter that follows a
    lowercase letter or digit, or that starts a new word after an acronym
    (``"HTTPPort"`` -> ``"http_port"``). The result is then lower-cased.
    """
    boundary = re.compile(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")
    underscored = boundary.sub(lambda match: "_" + match.group(1), s)
    return underscored.lower()


def convert_dict_key(d: dict, conv: Callable[[str], str]) -> dict:
    """Return a copy of ``d`` with every key passed through ``conv``.

    Nested dicts are converted recursively, including dicts found anywhere
    inside (possibly nested) lists. Other values are kept as-is, and the
    input is not modified.
    """

    def convert_value(v: Any) -> Any:
        if isinstance(v, dict):
            return convert_dict_key(v, conv)
        if isinstance(v, list):
            return [convert_value(item) for item in v]
        return v

    converted = {}
    for key, value in d.items():
        new_key = conv(key)  # key first, matching dict-comprehension order
        converted[new_key] = convert_value(value)
    return converted


class AppConfig:
    """Process-wide holder of the validated application configuration."""

    config: Config  # populated by init()

    @staticmethod
    def init(config_text: str) -> None:
        """Parse the YAML ``config_text`` and store the validated ``Config``.

        pyaml_env's ``parse_config`` resolves ``!ENV`` tags. camelCase keys are
        then converted to snake_case and the result is validated by ``Config``.

        The parsed dict is deliberately not printed: it may contain
        credentials (e.g. inside ``db.url``). A raw print would also write a
        non-JSON line to stdout even when JSON logging is configured.

        Raises:
            pydantic.ValidationError: the configuration does not match Config.
        """
        config_dict_camel = parse_config(data=config_text)
        config_dict_snake = convert_dict_key(config_dict_camel, camel_to_snake)
        AppConfig.config = Config.parse_obj(config_dict_snake)

fastapi_sqlalchemy_v1/config.yml

---
# Development configuration; the Makefile `run`/`uvicorn` targets point
# CONFIG_FILE_PATH here. camelCase keys are converted to snake_case by
# AppConfig.init. `!ENV ${NAME:default}` values are resolved from environment
# variables by pyaml_env, falling back to the text after the colon.
db:
  url: sqlite:///dev.db.sqlite3
  showSql: !ENV ${DB_SHOW_SQL:false}
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  # "JSON" selects JSON logging; any other value gives plain text.
  format: !ENV ${LOG_FORMAT:PLAIN}
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # NOTE(review): the variable is named APP_SCHEMA although the key is `scheme`.
  scheme: !ENV ${APP_SCHEMA:http}
  host: !ENV ${APP_HOST:localhost:8000}

fastapi_sqlalchemy_v1/book_controller.py

import logging
import traceback

from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.book_repository import (
    BookAddParameter,
    BookNotFoundError,
    BookRepository,
)
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database

logger = logging.getLogger(__name__)

# Router holding the /book endpoints defined below.
router = APIRouter()


class BookResponseHTTPEntity(BaseModel):
    """Response body describing a single book."""

    id: int = Field(...)
    name: str = Field(..., example="book")
    lang2: str = Field(..., example="ja")
    created_at: str = Field(...)  # ISO-8601 text from datetime.isoformat()

    @classmethod
    def new(cls, book: Book) -> "BookResponseHTTPEntity":
        """Build the response entity from a domain ``Book``."""
        return cls(
            id=book.book_id,
            name=book.name,
            lang2=book.lang2,
            created_at=book.created_at.isoformat(),
        )


class BookAddParameterHTTPEntity(BaseModel):
    """Request body for registering a book."""

    name: str = Field(..., min_length=1, max_length=20)
    # Exactly two characters. Validating it here makes bad input a 422 request
    # error instead of a failure inside the endpoint.
    lang2: str = Field(..., min_length=2, max_length=2, example="ja")


def book_repository(session: Session) -> BookRepository:
    """Create a BookRepository bound to ``session``.

    The endpoints go through this factory rather than the class directly, and
    tests/test_book_controller.py patches it to inject failures.
    """
    repository = BookRepository(session=session)
    return repository


@router.get("/book/{book_id}", response_model=BookResponseHTTPEntity, tags=["Book"])
def get_book(book_id: int, session: Session = Depends(Database.get_session)):
    # Look up one book: 404 "item_not_found" when it does not exist,
    # 500 "internal server error" for any other failure.
    # (Comments rather than a docstring: FastAPI would publish a docstring as
    # the operation description in the OpenAPI schema.)
    logger.info("get_book")
    try:
        found = book_repository(session).find_book_by_id(book_id)
        return BookResponseHTTPEntity.new(found)
    except BookNotFoundError:
        raise HTTPException(status_code=404, detail="item_not_found")
    except Exception as err:
        logger.error(str(err))
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="internal server error")


@router.post(
    "/book",
    status_code=status.HTTP_201_CREATED,
    tags=["Book"],
    responses={201: {"model": None}},
)
def add_book(
    body: BookAddParameterHTTPEntity, session: Session = Depends(Database.get_session)
):
    """Register a new book; responds 201 with its URL in the Location header."""
    logger.info("add_book")
    app_config = AppConfig.config.app
    try:
        book_add_param = BookAddParameter(name=body.name, lang2=body.lang2)
        book_repo: BookRepository = book_repository(session)
        book_id = book_repo.add_book(book_add_param)
        # Commit before answering. With FastAPI < 0.106 the code after `yield`
        # in Database.get_session (its commit) runs only after the response has
        # been sent, so a failing commit would otherwise hide behind a 201.
        session.commit()
        url = f"{app_config.scheme}://{app_config.host}/book/{book_id}"
        return Response(status_code=status.HTTP_201_CREATED, headers={"location": url})
    except Exception as e:
        logger.error(str(e))
        logger.error(traceback.format_exc())
        # Leave the session clean so the dependency's later commit is a no-op.
        session.rollback()
        raise HTTPException(status_code=500, detail="internal server error")

fastapi_sqlalchemy_v1/const.py

# Importable package name; main() uses it to build the "<package>.main:app"
# import string handed to uvicorn.
ROOT_PACKAGE = "fastapi_sqlalchemy_v1"

fastapi_sqlalchemy_v1/main.py

import logging
import os
import sys

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from fastapi_sqlalchemy_v1 import book_controller, const, healthcheck_controller, log
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database

# --- configuration ------------------------------------------------------------
# The config file path has no default; refuse to start without it.
config_file_path = os.getenv("CONFIG_FILE_PATH")
if not config_file_path:
    # sys.exit() with a message writes it to stderr and exits with status 1.
    sys.exit("CONFIG_FILE_PATH is not specified.")
with open(config_file_path, encoding="utf-8") as f:
    config_text = f.read()

AppConfig.init(config_text)

# logging: "JSON" selects JSON output, anything else plain text
if AppConfig.config.log.format == "JSON":
    log.init_json_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)
else:
    log.init_plaintext_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)

logger = logging.getLogger(__name__)
# db
# NOTE(review): tables are not created here; only tests/conftest.py calls
# Database.create_database(). Confirm that migrations (e.g. alembic) cover it.
Database.init(db_url=AppConfig.config.db.url, echo=AppConfig.config.db.show_sql)

# app
if AppConfig.config.swagger.enabled:
    app = FastAPI()
else:
    # docs_url/redoc_url only remove the UIs; openapi_url=None also stops
    # serving the raw /openapi.json schema.
    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

app.add_middleware(
    CORSMiddleware,
    allow_origins=AppConfig.config.cors.allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(book_controller.router)
app.include_router(healthcheck_controller.router)


def main():
    """Serve ``app`` with uvicorn on all interfaces, with auto-reload enabled.

    uvicorn receives the import string of this module's ``app`` (reload mode
    needs an import string rather than the object). The port comes from the
    ``app.httpPort`` setting.
    """
    target = f"{const.ROOT_PACKAGE}.main:app"
    port = AppConfig.config.app.http_port
    uvicorn.run(target, host="0.0.0.0", port=port, reload=True)


if __name__ == "__main__":
    main()

テストコード

tests/conftest.py

import os

import pytest

from fastapi_sqlalchemy_v1.database import Database


@pytest.fixture(scope="session")
def test_db():
    """Session-wide fixture that recreates the SQLite test database.

    It deletes any ``test.db.sqlite3`` left over from a previous run, creates
    the tables registered on ``Base`` and yields a session from the app's
    scoped session registry. The session is closed at teardown.

    NOTE(review): Database.init must already have run. Here that happens when
    the test modules import fastapi_sqlalchemy_v1.main at collection time.
    """
    print("SetUp")
    try:
        # Must match db.url in tests/test_config.yml.
        os.remove("test.db.sqlite3")
    except FileNotFoundError as e:
        # First run: nothing to clean up. Other OS errors (e.g. permission
        # denied) propagate, since tests would otherwise run on stale data.
        print(e)
    Database.create_database()
    session = Database._session_factory()
    yield session
    session.close()
    print("TearDown")

tests/test_book_controller.py

from datetime import datetime

from fastapi.testclient import TestClient
from httpx import Response

from fastapi_sqlalchemy_v1.main import app

from tests.conftest import test_db

# Module-wide client. raise_server_exceptions=False makes unhandled server
# errors come back as HTTP 500 responses instead of being re-raised in the test.
client = TestClient(app, raise_server_exceptions=False)


def test_add_book(test_db):
    """POST /book creates a book, and its Location URL serves the book back."""
    base_url = "http://localhost:8000"

    # when: register new data
    created: Response = client.post("/book", json={"name": "NAME", "lang2": "ja"})

    # then
    assert created.status_code == 201
    assert len(created.content) == 0
    assert created.headers["content-length"] == "0"
    assert created.headers["location"] == "http://localhost:8000/book/1"
    relative_location = created.headers["location"][len(base_url):]
    print(relative_location)

    # when: get the registered data
    fetched = client.get(relative_location)

    # then
    assert fetched.status_code == 200
    body = fetched.json()
    assert body["name"] == "NAME"
    assert body["lang2"] == "ja"
    created_at = datetime.fromisoformat(body["created_at"])
    assert abs((datetime.utcnow() - created_at).total_seconds()) < 3


def test_add_book_internal_server_error(test_db, mocker):
    """POST /book answers 500 when creating the repository fails."""
    # given: the repository factory raises
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository",
        side_effect=Exception(),
    )

    # when: register new data
    response: Response = client.post("/book", json={"name": "NAME", "lang2": "ja"})

    # then
    assert response.status_code == 500
    assert response.json() == {"detail": "internal server error"}


def test_find_book_that_doesnt_exist(test_db):
    """GET /book/{id} for an id that was never registered answers 404."""
    # when
    missing = client.get("/book/999")

    # then
    assert missing.status_code == 404
    assert missing.json() == {"detail": "item_not_found"}


def test_find_book_internal_server_error(test_db, mocker):
    """GET /book/{id} answers 500 when creating the repository fails."""
    # given: the repository factory raises
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository",
        side_effect=Exception(),
    )

    # when
    failed = client.get("/book/1")

    # then
    assert failed.status_code == 500
    assert failed.json() == {"detail": "internal server error"}

tests/test_healthcheck_controller.py

from fastapi.testclient import TestClient

from fastapi_sqlalchemy_v1.main import app

# Module-wide client; server errors are returned as 500 responses rather than
# re-raised (raise_server_exceptions=False).
client = TestClient(app, raise_server_exceptions=False)


def test_find_book_that_doesnt_exist(test_db):
    """GET /health answers 200 with an empty body.

    NOTE(review): the name looks copied from test_book_controller.py; something
    like test_healthcheck would describe what is checked.
    """
    # when
    health = client.get("/health")

    # then
    assert health.status_code == 200
    assert health.headers["Content-Length"] == "0"

tests/test_config.yml

---
# Configuration for `make test` (CONFIG_FILE_PATH=tests/test_config.yml).
db:
  # Must match the file name tests/conftest.py deletes before each session.
  url: sqlite:///test.db.sqlite3
  showSql: true
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  format: !ENV ${LOG_FORMAT:PLAIN}
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # Fixed (no !ENV) so the Location header asserted in test_add_book is stable.
  scheme: http
  host: localhost:8000

以上です。