Python + FastAPI + SQLAlchemy + DependencyInjectorで実装するWebAPI

Python + FastAPI + SQLAlchemy + DependencyInjectorで実装するWebAPI

  • OS: Xubuntu
  • Python: 3.10.x
  • Shell: zsh

プロジェクト作成

環境用意

1
2
3
4
5
6
7
8
# `poetry new` runs before the virtualenv below exists, so this first step
# needs a Poetry that is already installed globally (e.g. via pipx).
# The name must match pyproject.toml and the Makefile: it produces the
# `fastapi-sqlalchemy-v1` distribution and the `fastapi_sqlalchemy_v1` package.
poetry new fastapi-sqlalchemy-v1
cd fastapi-sqlalchemy-v1
python -m venv venv
. venv/bin/activate
pip install poetry
# Re-activate, presumably so the shell (zsh caches command paths) resolves
# `poetry` to the copy just installed into venv/bin.
deactivate
. venv/bin/activate
poetry install

コード

pyproject.toml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Project metadata; the package directory matches APP_NAME in the Makefile.
[tool.poetry]
name = "fastapi-sqlalchemy-v1"
version = "0.1.0"
description = ""
authors = ["pecolynx <pecolynx@gmail.com>"]
readme = "README.md"
packages = [{include = "fastapi_sqlalchemy_v1"}]


# Runtime dependencies.
# NOTE(review): the code uses the pydantic v1 API (BaseModel.parse_obj,
# Field(example=...)) that comes with FastAPI 0.95 - re-check before upgrading.
[tool.poetry.dependencies]
python = "^3.10"
fastapi = "^0.95.0"
# The `mypy` extra is presumably for the SQLAlchemy mypy plugin under [tool.mypy].
sqlalchemy = {extras = ["mypy"], version = "^2.0.7"}
uvicorn = "^0.21.1"
# Expands `!ENV ${VAR:default}` tags in config.yml (see config.py).
pyaml-env = "^1.2.1"


# Formatters, linters and the type checker run by the Makefile targets.
[tool.poetry.group.dev.dependencies]
isort = "^5.12.0"
ruff = "^0.0.257"
black = "^23.1.0"
flake8 = "^6.0.0"
mypy = "^1.1.1"


# Test-only dependencies: the tests import httpx directly, `make test` uses
# pytest-cov, and pytest-mock provides the `mocker` fixture.
[tool.poetry.group.test.dependencies]
pytest = "^7.2.2"
httpx = "^0.23.3"
pytest-cov = "^4.0.0"
pytest-mock = "^3.10.0"


# profile = "black" keeps isort's import layout compatible with black.
[tool.isort]
profile = "black"
skip_glob = ["*_pb2.py", "*_pb2_grpc.py"]


# line-length 88 matches flake8's max-line-length in setup.cfg.
# `exclude` spans several lines, so black treats it as a verbose regex; the
# `#` comments inside it are part of that regex.
[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
(
/(
\.eggs # exclude a few common directories in the
| \.git # root of the project
| \.hg
| \.mypy_cache
| \.tox
| \.venv
| _build
| buck-out
| build
| dist
| venv
| alembic
)/
| .*_pb2.py # exclude autogenerated Protocol Buffer files anywhere in the project
| .*_pb2_grpc.py # exclude autogenerated Protocol Buffer files anywhere in the project
)
'''

# Setting `exclude` replaces ruff's default exclusions, which is why the
# defaults are repeated here alongside venv and alembic.
[tool.ruff]
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".pytype",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
"alembic",
]

# NOTE(review): SQLAlchemy 2.0 documents this mypy plugin as deprecated in favour
# of Mapped[] annotations; the models still use Column(), which relies on it.
[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
exclude = [
'tests',
]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

setup.cfg

1
2
3
4
5
6
7
[flake8]
# Same limit as black's line-length in pyproject.toml.
max-line-length = 88
# E203 (whitespace before ':') conflicts with how black formats slices.
extend-ignore = E203
# Continuation lines of a multi-line value must be indented; otherwise
# configparser (and therefore flake8) cannot parse this section.
exclude =
    .git,
    __pycache__,
    venv,
    alembic,

Makefile

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
CMD:=poetry run
APP_NAME:=fastapi_sqlalchemy_v1
# NOTE(review): IMAGE_NAME and DB_NAME are not used by any target below.
IMAGE_NAME:=fastapi-sqlalchemy-v1
DB_NAME:=dev.db.sqlite3
SHELL=/bin/bash

# Every recipe line below must start with a TAB character; make rejects
# unindented or space-indented recipe lines with "missing separator".

# Format, lint and type-check the whole project.
.PHONY: all
all:
	$(CMD) isort .
	$(CMD) black .
	$(CMD) flake8 .
	$(CMD) ruff check .
	$(CMD) mypy .

# Rewrite import order and formatting in place.
.PHONY: format
format:
	$(CMD) isort .
	$(CMD) black .

.PHONY: type
type:
	$(CMD) mypy .

.PHONY: check
check:
	$(CMD) flake8 .
	$(CMD) ruff check .

# Run the test suite against tests/test_config.yml with branch coverage.
.PHONY: test
test:
	CONFIG_FILE_PATH=tests/test_config.yml \
	$(CMD) pytest --cov=$(APP_NAME) --cov-branch --cov-report=term --cov-report=xml -v -s

# Start the app through main.py (uvicorn with auto-reload).
.PHONY: run
run:
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) python $(APP_NAME)/main.py

# Start uvicorn directly, with JSON-formatted logs.
.PHONY: uvicorn
uvicorn:
	LOG_FORMAT=JSON \
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) uvicorn $(APP_NAME).main:app --reload

fastapi_sqlalchemy_v1/book.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from datetime import datetime


class Book:
    """Domain model of a book, independent of the database and HTTP layers.

    Values are set once in ``__init__`` and exposed through read-only
    properties (there are no setters).
    """

    def __init__(self, book_id: int, name: str, lang2: str, created_at: datetime):
        self.__book_id = book_id
        self.__name = name
        self.__lang2 = lang2
        self.__created_at = created_at

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"book_id={self.book_id!r}, name={self.name!r}, "
            f"lang2={self.lang2!r}, created_at={self.created_at!r})"
        )

    @property
    def book_id(self) -> int:
        """Primary key of the book."""
        return self.__book_id

    @property
    def name(self) -> str:
        """Title of the book."""
        return self.__name

    @property
    def lang2(self) -> str:
        """Two-character language code, e.g. ``"ja"``."""
        return self.__lang2

    @property
    def created_at(self) -> datetime:
        """Creation timestamp of the stored row."""
        return self.__created_at

fastapi_sqlalchemy_v1/book_repository.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import Optional

from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Integer, Text
from sqlalchemy.orm import Session
from sqlalchemy.sql import func

from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.database import Base


class BookNotFoundError(Exception):
    """Raised when a book lookup finds no matching row.

    ``key`` names the field that was searched (e.g. ``"id"``) and ``value`` is
    the searched value as a string; both are also kept in ``args``.
    """

    def __init__(self, key: str, value: str):
        super().__init__(key, value)
        self.key, self.value = key, value


class BookOtherError(Exception):
    """Raised by BookRepository when a stored book cannot be turned into a result.

    This happens when the entity lacks its id or one of its column values.
    """


class BookAddParameter(BaseModel):
    """Validated input for ``BookRepository.add_book``."""

    # Title: 1 to 20 characters.
    name: str = Field(..., min_length=1, max_length=20)
    # Language code: exactly 2 characters (e.g. "ja").
    lang2: str = Field(..., min_length=2, max_length=2)


class BookDbEntity(Base):
    """SQLAlchemy ORM mapping of the ``book`` table."""

    __tablename__ = "book"

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)
    lang2 = Column(Text, nullable=False)
    # Filled in by the database (server-side default) when the row is inserted.
    created_at = Column(DateTime, nullable=False, server_default=func.now())

    def to_model(self) -> Optional[Book]:
        """Convert this row into a ``Book`` domain object.

        Returns ``None`` when any column is still unset (e.g. a transient
        instance that has not been inserted), so callers can detect an
        incomplete entity.
        """
        # Compare with None explicitly: a truthiness test would wrongly treat
        # legitimate falsy values (id 0, an empty name) as "missing".
        if (
            self.id is None
            or self.name is None
            or self.lang2 is None
            or self.created_at is None
        ):
            return None
        return Book(
            book_id=self.id,
            name=self.name,
            lang2=self.lang2,
            created_at=self.created_at,
        )


class BookRepository:
    """Data access for books through a single SQLAlchemy ``Session``.

    The repository flushes but never commits; committing or rolling back is
    left to whoever owns the session.
    """

    def __init__(self, session: Session):
        self.__session = session

    def find_book_by_id(self, book_id: int) -> Book:
        """Return the book whose primary key is ``book_id``.

        Raises:
            BookNotFoundError: no row has that id.
            BookOtherError: the row exists but could not be converted to a Book.
        """
        # Session.get is the SQLAlchemy 2.0 primary-key lookup (checks the
        # identity map, otherwise SELECTs); it replaces the legacy Query API.
        book_entity = self.__session.get(BookDbEntity, book_id)
        if book_entity is None:
            raise BookNotFoundError("id", str(book_id))
        book = book_entity.to_model()
        if book is None:
            raise BookOtherError()
        return book

    def add_book(self, book_add_param: BookAddParameter) -> int:
        """Insert a new book and return its generated id (flushed, not committed).

        Raises:
            BookOtherError: the database did not assign an id.
        """
        book_entity = BookDbEntity(
            name=book_add_param.name, lang2=book_add_param.lang2
        )
        self.__session.add(book_entity)
        # flush() runs the INSERT so the primary key is assigned.
        self.__session.flush()
        # Reload the row so server-side defaults (created_at) are populated.
        self.__session.refresh(book_entity)
        # `is None`, not truthiness: 0 is a valid primary key value.
        if book_entity.id is None:
            raise BookOtherError()
        return book_entity.id

fastapi_sqlalchemy_v1/database.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import logging
from typing import Generator

from sqlalchemy import create_engine, orm
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, scoped_session

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Declarative base shared by all ORM entities (e.g. BookDbEntity);
# Database.create_database() creates the tables registered on its metadata.
Base = declarative_base()


class Database:
    """Process-wide holder of the SQLAlchemy engine and session factory.

    ``init`` must run first (main.py calls it at import time) before
    ``create_database`` or ``get_session`` are used.
    """

    _engine: Engine
    _session_factory: scoped_session

    @staticmethod
    def init(db_url: str, echo: bool) -> None:
        """Create the engine for ``db_url`` and a session factory bound to it.

        ``echo`` turns on SQLAlchemy's SQL statement logging.
        """
        Database._engine = create_engine(db_url, echo=echo, pool_pre_ping=True)
        Database._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=Database._engine,
            ),
        )

    @staticmethod
    def create_database() -> None:
        """Create all tables registered on ``Base`` that do not exist yet."""
        Base.metadata.create_all(Database._engine)  # type: ignore

    @staticmethod
    def get_session() -> Generator[Session, None, None]:
        """FastAPI dependency that yields a fresh Session for one request.

        Commits when the request finishes normally. Rolls back and re-raises
        when an exception is thrown into the generator. Always closes the
        session.
        """
        # Take a new Session from the plain sessionmaker behind the scoped
        # registry instead of the thread-local scoped one. FastAPI runs the
        # setup of sync generator dependencies on pooled worker threads, so a
        # new request can start on a thread whose scoped session still belongs
        # to an in-flight request. Both requests would then share one Session,
        # which is not safe.
        session: Session = Database._session_factory.session_factory()
        try:
            yield session
            logger.info("commit")
            session.commit()
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            logger.info("close")
            session.close()

fastapi_sqlalchemy_v1/log.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import json
import logging
from logging import LogRecord

from fastapi_sqlalchemy_v1.config import Logger


def get_app_log(record: LogRecord) -> dict[str, object]:
    """Build the JSON-serializable dict for an application log record.

    Call this after ``logging.Formatter.format`` has processed ``record``.
    That call sets ``record.message``, ``record.asctime`` (when the format
    string uses it) and ``record.exc_text`` (when exception info is attached),
    all of which are read here.
    """
    json_obj: dict[str, object] = {
        "name": record.name,
        "level": record.levelname,
        "type": "app",
        "timestamp": record.asctime,
        "pathname": record.pathname,
        "line": record.lineno,
        "threadId": record.thread,
        "message": record.message,
        "stack": record.stack_info,
        # Formatted traceback from logger.exception() / exc_info=True; without
        # this key, exception details never reached the JSON output.
        "exception": record.exc_text,
    }
    return json_obj


def get_access_log(record: LogRecord) -> dict[str, object]:
    """Build the JSON-serializable dict for an access log record.

    The record's args are expected to be
    ``(client_addr, method, path, http_version, status_code)``.
    NOTE(review): this is the shape uvicorn's access logger uses; confirm it
    when upgrading uvicorn. Records whose args do not have that shape get the
    rendered message instead of raising inside the logging handler.
    """
    json_obj: dict[str, object] = {
        "name": record.name,
        "level": record.levelname,
        "type": "access",
        "timestamp": record.asctime,
        "threadId": record.thread,
    }
    args = record.args
    if isinstance(args, tuple) and len(args) >= 5:
        client_addr, method, path, http_version, status_code = args[:5]
        json_obj.update(
            clientAddr=client_addr,
            method=method,
            path=path,
            httpVersion=http_version,
            statusCode=status_code,
        )
    else:
        # Not an access-log-shaped record (e.g. a plain message logged on the
        # access logger): fall back to the rendered message.
        json_obj["message"] = record.getMessage()
    return json_obj


class AppLogJsonFormatter(logging.Formatter):
    """Formatter that renders an application log record as a single JSON object."""

    def format(self, record):
        # Run the stock formatting only for its side effects on the record
        # (message, asctime, exc_text); its string result is discarded.
        super().format(record)
        payload = get_app_log(record)
        return json.dumps(payload)


class AccessLogJsonFormatter(logging.Formatter):
    """Formatter that renders an access log record as a single JSON object."""

    def format(self, record):
        # Run the stock formatting only for its side effects on the record
        # (message, asctime); its string result is discarded.
        super().format(record)
        payload = get_access_log(record)
        return json.dumps(payload)


def init_json_logger(level: str, loggers: list[Logger]):
    """Configure logging so every record is written as one JSON object.

    Args:
        level: Default level for the root logger and the library loggers
            routed below.
        loggers: Per-logger level overrides from the config (``log.loggers``).
            Every listed name gets its level applied; names routed below also
            get a JSON handler.
    """
    app_log_handler = logging.StreamHandler()
    app_log_handler.setFormatter(AppLogJsonFormatter("%(asctime)s"))

    access_log_handler = logging.StreamHandler()
    access_log_handler.setFormatter(AccessLogJsonFormatter("%(asctime)s"))

    logging.basicConfig(handlers=[app_log_handler], level=level)

    level_overrides = {log.name: log.level for log in loggers}

    # Library loggers: replace any existing handlers with the JSON one and
    # stop propagation to the root handler, so each record is written once.
    routes = {
        "sqlalchemy.engine.Engine": app_log_handler,
        "uvicorn": app_log_handler,
        "uvicorn.access": access_log_handler,
        "gunicorn": app_log_handler,
    }
    for name, handler in routes.items():
        target = logging.getLogger(name)
        target.handlers.clear()
        target.addHandler(handler)
        target.propagate = False
        target.setLevel(level_overrides.get(name, level))

    # Overrides for any other logger only adjust its level; its records still
    # propagate to the JSON root handler.
    for name, override in level_overrides.items():
        if name not in routes:
            logging.getLogger(name).setLevel(override)


def init_plaintext_logger(level: str, loggers: list[Logger]):
    """Configure standard plain-text logging.

    Args:
        level: Default level for the root logger and the library loggers below.
        loggers: Per-logger level overrides from the config (``log.loggers``);
            every listed name gets its level applied.
    """
    logging.basicConfig(level=level)

    level_overrides = {log.name: log.level for log in loggers}

    # Keep these library loggers out of the root handler.
    # NOTE(review): they then only print through handlers of their own
    # (uvicorn installs its own; SQLAlchemy adds one when echo=True) - confirm
    # if the way the app is launched changes.
    isolated = ("sqlalchemy.engine.Engine", "uvicorn", "uvicorn.access")
    for name in isolated:
        target = logging.getLogger(name)
        target.propagate = False
        target.setLevel(level_overrides.get(name, level))

    # Overrides for any other logger only adjust its level.
    for name, override in level_overrides.items():
        if name not in isolated:
            logging.getLogger(name).setLevel(override)

fastapi_sqlalchemy_v1/config.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import re
from typing import Any, Callable

from pyaml_env import parse_config # type: ignore
from pydantic import BaseModel


class Db(BaseModel):
    """``db`` section of the config file."""

    # SQLAlchemy database URL passed to Database.init.
    url: str
    # Passed to Database.init as ``echo`` (SQL statement logging).
    show_sql: bool


class Cors(BaseModel):
    """``cors`` section of the config file."""

    # Origins allowed by the CORSMiddleware installed in main.py.
    allow_origins: list[str]


class Logger(BaseModel):
    """One ``log.loggers`` entry: a level override for the named logger."""

    name: str
    level: str


class Log(BaseModel):
    """``log`` section of the config file."""

    # Default level (e.g. "INFO") for the root and library loggers.
    level: str
    # "JSON" selects JSON logging in main.py; any other value gives plain text.
    format: str
    loggers: list[Logger]


class Swagger(BaseModel):
    """``swagger`` section of the config file."""

    # When false, main.py builds the app without the docs/redoc pages.
    enabled: bool


class App(BaseModel):
    """``app`` section of the config file."""

    # Port main() passes to uvicorn.
    http_port: int
    # scheme and host build the Location URL returned by POST /book
    # (e.g. "http" and "localhost:8000").
    scheme: str
    host: str


class Config(BaseModel):
    """Root config model; its fields mirror the file's top-level sections."""

    db: Db
    cors: Cors
    log: Log
    swagger: Swagger
    app: App


def camel_to_snake(s: str) -> str:
    """Convert a camelCase / PascalCase identifier to snake_case.

    An underscore is inserted before an uppercase letter that follows a
    lowercase letter or digit, or that begins a new word after an acronym
    (``"HTTPPort"`` -> ``"http_port"``). The result is lowercased.
    """
    word_boundary = r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))"
    underscored = re.sub(word_boundary, r"_\1", s)
    return underscored.lower()


def convert_dict_key(d: dict, conv: Callable[[str], str]) -> dict:
    """Return a copy of ``d`` with every key renamed by ``conv``, recursively.

    Nested dicts are converted too, including dicts inside (nested) lists. All
    other values are returned unchanged. ``d`` itself is not modified.
    """

    def convert_value(v: Any) -> Any:
        if isinstance(v, dict):
            return convert_dict_key(v, conv)
        if isinstance(v, list):
            return [convert_value(item) for item in v]
        return v

    converted = {}
    for key, value in d.items():
        new_key = conv(key)
        converted[new_key] = convert_value(value)
    return converted


class AppConfig:
    """Process-wide holder of the parsed application configuration.

    ``AppConfig.config`` exists only after ``init`` has been called.
    """

    config: Config

    @staticmethod
    def init(config_text: str):
        """Parse the YAML ``config_text`` and store it in ``AppConfig.config``.

        ``!ENV ${VAR:default}`` tags are expanded by pyaml_env. camelCase keys
        are converted to snake_case, and the result is validated against
        ``Config``.
        """
        config_dict_camel = parse_config(data=config_text)
        config_dict_snake = convert_dict_key(config_dict_camel, camel_to_snake)
        # Do not print or log the parsed dict: it can hold secrets (for
        # example credentials embedded in db.url).
        AppConfig.config = Config.parse_obj(config_dict_snake)

fastapi_sqlalchemy_v1/config.yml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
---
# Keys are camelCase here; AppConfig.init converts them to snake_case before
# validating against the Config model. `!ENV ${VAR:default}` values are read
# from the environment (pyaml_env), falling back to the default after ':'.
db:
  url: sqlite:///dev.db.sqlite3
  showSql: !ENV ${DB_SHOW_SQL:false}
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  # "JSON" selects JSON logging; any other value gives plain-text logs.
  format: !ENV ${LOG_FORMAT:PLAIN}
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # NOTE: the variable is spelled APP_SCHEMA (not APP_SCHEME); kept as-is so
  # existing environments keep working.
  scheme: !ENV ${APP_SCHEMA:http}
  host: !ENV ${APP_HOST:localhost:8000}

fastapi_sqlalchemy_v1/book_controller.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import logging
import traceback

from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.book_repository import (
BookAddParameter,
BookNotFoundError,
BookRepository,
)
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router holding the /book endpoints; main.py mounts it via app.include_router.
router = APIRouter()


class BookResponseHTTPEntity(BaseModel):
    """JSON body returned by ``GET /book/{book_id}``."""

    id: int = Field(...)
    name: str = Field(..., example="book")
    lang2: str = Field(..., example="ja")
    # ISO 8601 string produced by datetime.isoformat().
    created_at: str = Field(...)

    @classmethod
    def new(cls, book: Book) -> "BookResponseHTTPEntity":
        """Build the response entity from a ``Book`` domain object."""
        # Instantiate through ``cls`` so subclasses construct their own type.
        return cls(
            id=book.book_id,
            name=book.name,
            lang2=book.lang2,
            created_at=book.created_at.isoformat(),
        )


class BookAddParameterHTTPEntity(BaseModel):
    """JSON body accepted by ``POST /book``."""

    name: str = Field(..., min_length=1, max_length=20)
    # Same constraints as BookAddParameter, so invalid input is rejected here
    # as a request-validation error instead of failing later inside add_book
    # (whose broad except turns every error into a 500).
    lang2: str = Field(..., min_length=2, max_length=2, example="ja")


def book_repository(session: Session) -> BookRepository:
    """Create the BookRepository the handlers use for this request's session.

    Kept as a module-level function so tests can patch it
    (``fastapi_sqlalchemy_v1.book_controller.book_repository``).
    """
    repository = BookRepository(session)
    return repository


@router.get("/book/{book_id}", response_model=BookResponseHTTPEntity, tags=["Book"])
def get_book(book_id: int, session: Session = Depends(Database.get_session)):
    """Return the book identified by ``book_id``.

    Responds 404 (``item_not_found``) when no such book exists. Any other
    failure is logged with its traceback and answered with 500
    (``internal server error``).
    """
    logger.info("get_book")
    try:
        found = book_repository(session).find_book_by_id(book_id)
        return BookResponseHTTPEntity.new(found)
    except BookNotFoundError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="item_not_found"
        )
    except Exception as err:
        logger.error(str(err))
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="internal server error",
        )


@router.post("/book", tags=["Book"], responses={201: {"model": None}})
def add_book(
    body: BookAddParameterHTTPEntity, session: Session = Depends(Database.get_session)
):
    """Create a book; respond 201 with its URL in the ``Location`` header.

    Any failure (including a failed commit) is logged with its traceback and
    answered with 500 (``internal server error``).
    """
    logger.info("add_book")
    app_config = AppConfig.config.app
    try:
        book_add_param = BookAddParameter(name=body.name, lang2=body.lang2)
        book_repo: BookRepository = book_repository(session)
        book_id = book_repo.add_book(book_add_param)
        # Commit before answering. In the FastAPI release pinned by this
        # project (< 0.106), teardown code of `yield` dependencies, including
        # the commit in Database.get_session, runs only after the response has
        # been sent. A failing commit there would leave the client holding a
        # 201 and a Location for a row that was never stored.
        session.commit()
        url = f"{app_config.scheme}://{app_config.host}/book/{book_id}"
        return Response(status_code=status.HTTP_201_CREATED, headers={"location": url})
    except Exception as e:
        logger.error(str(e))
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="internal server error")

fastapi_sqlalchemy_v1/const.py

1
# Import name of the application package; main() builds the
# "<package>.main:app" import string for uvicorn from it.
ROOT_PACKAGE = "fastapi_sqlalchemy_v1"

fastapi_sqlalchemy_v1/main.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import logging
import os
import sys

import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from fastapi_sqlalchemy_v1 import book_controller, const, healthcheck_controller, log
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database

# Application wiring. All of this runs at import time: importing this module
# (as uvicorn and the tests do) loads the config, sets up logging and the
# database, and builds ``app``.
config_file_path = os.getenv("CONFIG_FILE_PATH")
if not config_file_path:
    # sys.exit with a message writes it to stderr and exits with status 1.
    sys.exit("CONFIG_FILE_PATH is not specified.")
with open(config_file_path, encoding="utf-8") as f:
    config_text = f.read()

AppConfig.init(config_text)

# logging
if AppConfig.config.log.format == "JSON":
    log.init_json_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)
else:
    log.init_plaintext_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)

logger = logging.getLogger(__name__)
# db
# NOTE(review): tables are not created here (only tests/conftest.py calls
# Database.create_database()); presumably the schema is managed elsewhere
# (the lint configs exclude an alembic directory) - confirm.
Database.init(db_url=AppConfig.config.db.url, echo=AppConfig.config.db.show_sql)

# app
if AppConfig.config.swagger.enabled:
    app = FastAPI()
else:
    # Turn off the schema endpoint as well. With only docs_url/redoc_url
    # disabled, /openapi.json would still publish the full API description.
    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)

app.add_middleware(
    CORSMiddleware,
    allow_origins=AppConfig.config.cors.allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(book_controller.router)
app.include_router(healthcheck_controller.router)


def main():
    """Serve the app with uvicorn on all interfaces, with auto-reload on.

    The app is passed to uvicorn as an import string rather than an object,
    which uvicorn requires when ``reload`` is enabled.
    """
    app_import_path = f"{const.ROOT_PACKAGE}.main:app"
    listen_port = AppConfig.config.app.http_port
    uvicorn.run(
        app_import_path,
        host="0.0.0.0",
        port=listen_port,
        reload=True,
    )


if __name__ == "__main__":
    main()

テストコード

tests/conftest.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import os

import pytest

from fastapi_sqlalchemy_v1.database import Database


@pytest.fixture(scope="session")
def test_db():
    """Session-wide fixture that recreates the test SQLite database.

    Deletes ``test.db.sqlite3`` (the file tests/test_config.yml points the
    engine at), creates all tables, yields a Session and closes it on
    teardown. Relies on ``Database.init`` having run already, which happens
    when a test module imports ``fastapi_sqlalchemy_v1.main``.
    """
    import contextlib

    print("SetUp")
    # Only "file does not exist" (e.g. first run) is expected. Any other
    # error, such as a locked file, should fail the run instead of being
    # printed and ignored.
    with contextlib.suppress(FileNotFoundError):
        os.remove("test.db.sqlite3")
    Database.create_database()
    session = Database._session_factory()
    try:
        yield session
    finally:
        session.close()
        print("TearDown")
print("TearDown")

tests/test_book_controller.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from datetime import datetime

from fastapi.testclient import TestClient
from httpx import Response

from fastapi_sqlalchemy_v1.main import app

from tests.conftest import test_db

# Client shared by the tests in this module. raise_server_exceptions=False makes
# unhandled server errors come back as HTTP 500 responses instead of being
# re-raised inside the test.
client = TestClient(app, raise_server_exceptions=False)


def test_add_book(test_db):
    """POST /book answers 201 with a Location that a following GET resolves."""
    from urllib.parse import urlsplit

    # when
    # - register new data
    response: Response = client.post(
        "/book",
        json={"name": "NAME", "lang2": "ja"},
    )
    # then
    assert response.status_code == 201
    assert len(response.content) == 0
    assert response.headers["content-length"] == "0"
    assert response.headers["location"] == "http://localhost:8000/book/1"
    # Request the Location through the test client by its path only; urlsplit
    # avoids hard-coding the length of the "http://localhost:8000" prefix.
    relative_location = urlsplit(response.headers["location"]).path

    # when
    # - get registered data
    response = client.get(relative_location)

    # then
    assert response.status_code == 200
    json_content = response.json()
    assert json_content["name"] == "NAME"
    assert json_content["lang2"] == "ja"
    created_at = datetime.fromisoformat(json_content["created_at"])
    # Assumes the stored timestamp is a naive UTC value.
    assert abs((datetime.utcnow() - created_at).total_seconds()) < 3


def test_add_book_internal_server_error(test_db, mocker):
    """POST /book turns an unexpected repository failure into a 500."""
    # given
    # - creating the repository blows up
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository", side_effect=Exception()
    )
    payload = {"name": "NAME", "lang2": "ja"}
    # when
    # - register new data
    response: Response = client.post("/book", json=payload)
    # then
    assert response.status_code == 500
    assert response.json() == {"detail": "internal server error"}


def test_find_book_that_doesnt_exist(test_db):
    """GET for an unknown id answers 404 with the item_not_found detail."""
    # when
    missing = client.get("/book/999")
    # then
    assert missing.status_code == 404
    assert missing.json() == {"detail": "item_not_found"}


def test_find_book_internal_server_error(test_db, mocker):
    """GET /book/{id} turns an unexpected repository failure into a 500."""
    # given
    # - creating the repository blows up
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository", side_effect=Exception()
    )
    # when
    failed = client.get("/book/1")
    # then
    assert failed.status_code == 500
    assert failed.json() == {"detail": "internal server error"}

tests/test_healthcheck_controller.py

1
2
3
4
5
6
7
8
9
10
11
12
13
from fastapi.testclient import TestClient

from fastapi_sqlalchemy_v1.main import app

# Client for this module. raise_server_exceptions=False makes unhandled server
# errors come back as HTTP 500 responses instead of being re-raised in the test.
client = TestClient(app, raise_server_exceptions=False)


def test_healthcheck(test_db):
    """GET /health answers 200 with an empty body.

    Renamed from the copy-pasted ``test_find_book_that_doesnt_exist``, which
    did not describe what this test checks.
    """
    # when
    response = client.get("/health")
    # then
    assert response.status_code == 200
    assert response.headers["Content-Length"] == "0"

tests/test_config.yml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
---
# Test configuration (used by `make test`). Keys are camelCase and converted
# to snake_case by AppConfig.init.
db:
  # Must be the file tests/conftest.py deletes and recreates.
  url: sqlite:///test.db.sqlite3
  showSql: true
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  format: !ENV ${LOG_FORMAT:PLAIN}
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # Fixed (not env-driven) because the tests assert the Location header
  # "http://localhost:8000/book/...".
  scheme: http
  host: localhost:8000

以上です。