Python + FastAPI + SQLAlchemy + DependencyInjectorで実装するWebAPI
- OS: Xubuntu
- Python: 3.10.x
- Shell: zsh
プロジェクト作成
環境用意
# Create the project skeleton (poetry must already be on PATH at this point,
# since it is only installed into the venv further below).
# NOTE(review): pyproject.toml below renames the package to
# fastapi_sqlalchemy_v1, so edit the generated file accordingly.
poetry new python-fastapi-sqlalchemy
cd python-fastapi-sqlalchemy
# Create and enter a project-local virtualenv.
python -m venv venv
. venv/bin/activate
# Install poetry inside the venv.
pip install poetry
# Re-activate, presumably so the shell (zsh command hash) resolves `poetry`
# to the venv's copy rather than the global one.
deactivate
. venv/bin/activate
# Install the dependencies declared in pyproject.toml into the active venv.
poetry install
コード
pyproject.toml
[tool.poetry]
name = "fastapi-sqlalchemy-v1"
version = "0.1.0"
description = ""
authors = ["pecolynx <pecolynx@gmail.com>"]
readme = "README.md"
packages = [{include = "fastapi_sqlalchemy_v1"}]

[tool.poetry.dependencies]
python = "^3.10"
fastapi = "^0.95.0"
# The application imports pydantic directly and relies on the v1 API
# (`Config.parse_obj`, `Field(example=...)`), so declare it explicitly instead
# of depending on FastAPI's transitive requirement.
pydantic = "^1.10"
sqlalchemy = {extras = ["mypy"], version = "^2.0.7"}
uvicorn = "^0.21.1"
pyaml-env = "^1.2.1"

[tool.poetry.group.dev.dependencies]
isort = "^5.12.0"
ruff = "^0.0.257"
black = "^23.1.0"
flake8 = "^6.0.0"
mypy = "^1.1.1"

[tool.poetry.group.test.dependencies]
pytest = "^7.2.2"
httpx = "^0.23.3"
pytest-cov = "^4.0.0"
pytest-mock = "^3.10.0"

[tool.isort]
profile = "black"
skip_glob = ["*_pb2.py", "*_pb2_grpc.py"]

[tool.black]
line-length = 88
include = '\.pyi?$'
exclude = '''
(
  /(
      \.eggs         # exclude a few common directories in the
    | \.git          # root of the project
    | \.hg
    | \.mypy_cache
    | \.tox
    | \.venv
    | _build
    | buck-out
    | build
    | dist
    | venv
    | alembic
  )/
  | .*_pb2\.py       # exclude autogenerated Protocol Buffer files anywhere in the project
  | .*_pb2_grpc\.py  # exclude autogenerated Protocol Buffer files anywhere in the project
)
'''

[tool.ruff]
exclude = [
    ".bzr",
    ".direnv",
    ".eggs",
    ".git",
    ".hg",
    ".mypy_cache",
    ".nox",
    ".pants.d",
    ".pytype",
    ".ruff_cache",
    ".svn",
    ".tox",
    ".venv",
    "__pypackages__",
    "_build",
    "buck-out",
    "build",
    "dist",
    "node_modules",
    "venv",
    "alembic",
]

[tool.mypy]
plugins = "sqlalchemy.ext.mypy.plugin"
exclude = [
    'tests',
]

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
setup.cfg
[flake8]
max-line-length = 88
# E203 (whitespace before ':') conflicts with the way black formats slices;
# ignoring it keeps `make check` consistent with `make format`.
extend-ignore = E203
exclude =
    .git,
    __pycache__,
    venv,
    alembic,
Makefile
CMD:=poetry run
APP_NAME:=fastapi_sqlalchemy_v1
IMAGE_NAME:=fastapi-sqlalchemy-v1
DB_NAME:=dev.db.sqlite3
SHELL:=/bin/bash

# None of these targets produce a file with the same name.
.PHONY: all format type check test run uvicorn

# Default target: format, lint and type-check everything.
all:
	$(CMD) isort .
	$(CMD) black .
	$(CMD) flake8 .
	$(CMD) ruff check .
	$(CMD) mypy .

# Rewrite sources in place (import order, then code style).
format:
	$(CMD) isort .
	$(CMD) black .

# Static type check only.
type:
	$(CMD) mypy .

# Lint only.
check:
	$(CMD) flake8 .
	$(CMD) ruff check .

# Run the test suite against the test configuration, with branch coverage.
test:
	CONFIG_FILE_PATH=tests/test_config.yml \
	$(CMD) pytest --cov=$(APP_NAME) --cov-branch --cov-report=term --cov-report=xml -v -s

# Start the app through main.py (which calls uvicorn.run itself).
run:
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) python $(APP_NAME)/main.py

# Start the app with the uvicorn CLI, auto-reload and JSON logs.
uvicorn:
	LOG_FORMAT=JSON \
	CONFIG_FILE_PATH=$(APP_NAME)/config.yml \
	$(CMD) uvicorn $(APP_NAME).main:app --reload
fastapi_sqlalchemy_v1/book.py
from datetime import datetime
class Book:
    """Domain model of a book.

    The values are stored in name-mangled attributes and exposed through
    read-only properties (there are no setters).
    """

    def __init__(self, book_id: int, name: str, lang2: str, created_at: datetime):
        self.__book_id = book_id
        self.__name = name
        self.__lang2 = lang2
        self.__created_at = created_at

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"book_id={self.book_id!r}, "
            f"name={self.name!r}, "
            f"lang2={self.lang2!r}, "
            f"created_at={self.created_at!r})"
        )

    @property
    def book_id(self) -> int:
        """Identifier of the book."""
        return self.__book_id

    @property
    def name(self) -> str:
        """Title of the book."""
        return self.__name

    @property
    def lang2(self) -> str:
        """Two-character language code of the book (e.g. ``"ja"``)."""
        return self.__lang2

    @property
    def created_at(self) -> datetime:
        """Creation timestamp of the book."""
        return self.__created_at
fastapi_sqlalchemy_v1/book_repository.py
from typing import Optional
from pydantic import BaseModel, Field
from sqlalchemy import Column, DateTime, Integer, Text
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.database import Base
class BookNotFoundError(Exception):
    """Raised when no book matches a lookup.

    Attributes:
        key: Name of the attribute that was searched (e.g. ``"id"``).
        value: The searched value, as a string.
    """

    def __init__(self, key: str, value: str):
        # Forward the arguments to Exception so ``args`` is populated; without
        # this ``str(err)`` is empty and the exception cannot be pickled or
        # copied (both rebuild it from ``args``).
        super().__init__(key, value)
        self.key = key
        self.value = value

    def __str__(self) -> str:
        return f"book not found: {self.key}={self.value}"
class BookOtherError(Exception):
    """Raised by the repository when a book operation fails for a reason other
    than the book not being found (e.g. an incomplete row)."""
# Validated input for BookRepository.add_book.
class BookAddParameter(BaseModel):
    # Title: 1 to 20 characters.
    name: str = Field(default=..., min_length=1, max_length=20)
    # Language code: exactly two characters (e.g. "ja").
    lang2: str = Field(default=..., min_length=2, max_length=2)
class BookDbEntity(Base):
    """SQLAlchemy mapping of the ``book`` table."""

    __tablename__ = "book"

    id = Column(Integer, primary_key=True)
    name = Column(Text, nullable=False)
    lang2 = Column(Text, nullable=False)
    # Filled in by the database (``now()``) when the row is inserted.
    created_at = Column(DateTime, nullable=False, server_default=func.now())

    def to_model(self) -> Optional[Book]:
        """Convert this row into a ``Book`` domain object.

        Returns:
            The ``Book``, or ``None`` if any mapped column is still ``None``.
        """
        # Compare against None rather than truthiness: an id of 0 or an empty
        # string are real stored values, not missing ones.
        if (
            self.id is None
            or self.name is None
            or self.lang2 is None
            or self.created_at is None
        ):
            return None
        return Book(
            book_id=self.id,
            name=self.name,
            lang2=self.lang2,
            created_at=self.created_at,
        )
class BookRepository:
    """Data access for ``book`` rows through a single SQLAlchemy session.

    The repository only flushes; committing or rolling back is left to the
    owner of the session.
    """

    def __init__(self, session: Session):
        self.__session = session

    def find_book_by_id(self, book_id: int) -> Book:
        """Return the book whose primary key is ``book_id``.

        Raises:
            BookNotFoundError: If no such row exists.
            BookOtherError: If the row cannot be converted into a ``Book``.
        """
        # Session.get is the SQLAlchemy 2.0 primary-key lookup (the Query API
        # is legacy); it also reuses an already-loaded instance if present.
        book_entity: BookDbEntity | None = self.__session.get(BookDbEntity, book_id)
        if book_entity is None:
            raise BookNotFoundError("id", str(book_id))
        book = book_entity.to_model()
        if book is None:
            raise BookOtherError()
        return book

    def add_book(self, book_add_param: BookAddParameter) -> int:
        """Insert a new book and return its generated id.

        Raises:
            BookOtherError: If no id was assigned after the insert.
        """
        book_entity = BookDbEntity(name=book_add_param.name, lang2=book_add_param.lang2)
        self.__session.add(book_entity)
        # Flush so the INSERT runs and the primary key is assigned, then
        # refresh to load server-generated columns such as created_at.
        self.__session.flush()
        self.__session.refresh(book_entity)
        # ``is None`` rather than truthiness: 0 is a valid primary key value.
        if book_entity.id is None:
            raise BookOtherError()
        return book_entity.id
fastapi_sqlalchemy_v1/database.py
import logging
from typing import Generator
from sqlalchemy import create_engine, orm
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session, declarative_base, scoped_session
# Declarative base that all ORM entity classes inherit from.
Base = declarative_base()
# Module logger (used for the session commit/rollback/close messages).
logger = logging.getLogger(__name__)
class Database:
    """Process-wide holder of the SQLAlchemy engine and session factory.

    ``init`` must be called before any other method; the attributes below are
    only assigned there.
    """

    _engine: Engine
    _session_factory: scoped_session

    @staticmethod
    def init(db_url: str, echo: bool) -> None:
        """Create the engine and the session factory.

        Args:
            db_url: SQLAlchemy database URL.
            echo: Passed to the engine's ``echo`` flag (log SQL statements).
        """
        Database._engine = create_engine(db_url, echo=echo, pool_pre_ping=True)
        Database._session_factory = orm.scoped_session(
            orm.sessionmaker(
                autocommit=False,
                autoflush=False,
                bind=Database._engine,
            ),
        )

    @staticmethod
    def create_database() -> None:
        """Create every table registered on ``Base`` that does not exist yet."""
        Base.metadata.create_all(Database._engine)  # type: ignore

    @staticmethod
    def get_session() -> Generator[Session, None, None]:
        """Yield a fresh session for one unit of work (one HTTP request).

        The session is committed when the consumer finishes normally, rolled
        back (and the exception re-raised) if an exception is thrown into the
        generator, and closed in every case.
        """
        # Build a brand-new Session from the underlying sessionmaker instead of
        # the thread-local scoped_session registry. FastAPI runs sync
        # dependencies in a worker thread pool, so two in-flight requests can
        # be set up on the same thread and would otherwise share (and close)
        # the very same Session.
        session: Session = Database._session_factory.session_factory()
        try:
            yield session
            logger.info("commit")
            session.commit()
        except Exception:
            logger.exception("Session rollback because of exception")
            session.rollback()
            raise
        finally:
            logger.info("close")
            session.close()
fastapi_sqlalchemy_v1/log.py
import json
import logging
from logging import LogRecord
from fastapi_sqlalchemy_v1.config import Logger
def get_app_log(record: LogRecord):
    """Build the JSON-serialisable payload for an application log record.

    ``record`` is expected to have gone through ``logging.Formatter.format``
    first (as ``AppLogJsonFormatter`` does). That call populates ``asctime``
    and ``message``, and for records carrying ``exc_info`` also ``exc_text``.

    Returns:
        A dict with the record's name, level, timestamp, location, thread,
        message, stack info and formatted exception (or ``None``).
    """
    return {
        "name": record.name,
        "level": record.levelname,
        "type": "app",
        "timestamp": record.asctime,
        "pathname": record.pathname,
        "line": record.lineno,
        "threadId": record.thread,
        "message": record.message,
        "stack": record.stack_info,
        # Traceback from exc_info (e.g. logger.exception); without it the JSON
        # output silently dropped every exception's stack trace.
        "exception": record.exc_text,
    }
def get_access_log(record):
    """Build the JSON-serialisable payload for a uvicorn access log record.

    ``record.args`` is read positionally as (client address, method, path,
    HTTP version, status code); ``record.asctime`` must already be set.
    """
    access_args = record.args
    payload = {
        "name": record.name,
        "level": record.levelname,
        "type": "access",
        "timestamp": record.asctime,
        "threadId": record.thread,
    }
    positional_keys = ("clientAddr", "method", "path", "httpVersion", "statusCode")
    for position, key in enumerate(positional_keys):
        payload[key] = access_args[position]
    return payload
class AppLogJsonFormatter(logging.Formatter):
    """Render application log records as a single JSON string each."""

    def format(self, record):
        # The base implementation is called only for its side effects on the
        # record (asctime, message, exc_text); its formatted text is discarded.
        super().format(record)
        payload = get_app_log(record)
        return json.dumps(payload)
class AccessLogJsonFormatter(logging.Formatter):
    """Render uvicorn access log records as a single JSON string each."""

    def format(self, record):
        # Called only so the record gets its asctime/message filled in; the
        # base class's formatted text is not used.
        super().format(record)
        payload = get_access_log(record)
        return json.dumps(payload)
def init_json_logger(level: str, loggers: list[Logger]):
    """Configure JSON logging for the root logger, SQLAlchemy and the servers.

    Args:
        level: Default level name (e.g. ``"INFO"``) for the root logger and
            for every logger without an explicit entry in ``loggers``.
        loggers: Per-logger level overrides, matched by ``Logger.name``.
    """
    app_log_handler = logging.StreamHandler()
    app_log_handler.setFormatter(AppLogJsonFormatter("%(asctime)s"))
    access_log_handler = logging.StreamHandler()
    access_log_handler.setFormatter(AccessLogJsonFormatter("%(asctime)s"))

    logging.basicConfig(handlers=[app_log_handler], level=level)

    loggers_dict = {log.name: log.level for log in loggers}

    # Library loggers that get exactly one dedicated JSON handler. Propagation
    # is disabled so their records are not emitted again by the root handler.
    dedicated_handlers: dict[str, logging.Handler] = {
        "sqlalchemy.engine.Engine": app_log_handler,
        "uvicorn": app_log_handler,
        "uvicorn.access": access_log_handler,
        "gunicorn": app_log_handler,
    }
    for name, handler in dedicated_handlers.items():
        named_logger = logging.getLogger(name)
        named_logger.handlers.clear()
        named_logger.addHandler(handler)
        named_logger.propagate = False
        # Every dedicated logger honours a configured override (previously
        # "uvicorn" and "gunicorn" ignored it).
        named_logger.setLevel(loggers_dict.get(name, level))

    # Other configured loggers keep propagating to the root handler; only
    # their level is applied (previously these entries were ignored).
    for name, logger_level in loggers_dict.items():
        if name not in dedicated_handlers:
            logging.getLogger(name).setLevel(logger_level)
def init_plaintext_logger(level: str, loggers: list[Logger]):
    """Configure plain-text logging through ``logging.basicConfig``.

    Args:
        level: Default level name (e.g. ``"INFO"``) for the root logger and
            for every logger without an explicit entry in ``loggers``.
        loggers: Per-logger level overrides, matched by ``Logger.name``.
    """
    logging.basicConfig(level=level)
    loggers_dict = {log.name: log.level for log in loggers}

    # These loggers do not propagate to root, so they only emit through
    # handlers attached to them elsewhere. NOTE(review): presumably
    # SQLAlchemy's echo handler and uvicorn's own logging config — confirm.
    non_propagating = ("sqlalchemy.engine.Engine", "uvicorn", "uvicorn.access")
    for name in non_propagating:
        named_logger = logging.getLogger(name)
        named_logger.propagate = False
        # "uvicorn" now honours a configured override like the other two.
        named_logger.setLevel(loggers_dict.get(name, level))

    # Any other configured logger only gets its level applied (previously
    # such entries were ignored).
    for name, logger_level in loggers_dict.items():
        if name not in non_propagating:
            logging.getLogger(name).setLevel(logger_level)
fastapi_sqlalchemy_v1/config.py
import re
from typing import Any, Callable
from pyaml_env import parse_config # type: ignore
from pydantic import BaseModel
# Database settings (``db`` section of the YAML config).
class Db(BaseModel):
    # SQLAlchemy database URL, e.g. ``sqlite:///dev.db.sqlite3``.
    url: str
    # Forwarded to the engine's ``echo`` flag (log SQL statements).
    show_sql: bool
# CORS settings (``cors`` section of the YAML config).
class Cors(BaseModel):
    # Origins handed to CORSMiddleware(allow_origins=...).
    allow_origins: list[str]
# Level override for one named logger (entry of ``log.loggers``).
class Logger(BaseModel):
    # Logger name, e.g. ``uvicorn.access``.
    name: str
    # Level name, e.g. ``INFO``.
    level: str
# Logging settings (``log`` section of the YAML config).
class Log(BaseModel):
    # Default level name for all loggers.
    level: str
    # "JSON" selects JSON output; anything else selects plain text.
    format: str
    # Per-logger level overrides.
    loggers: list[Logger]
# API docs settings (``swagger`` section of the YAML config).
class Swagger(BaseModel):
    # When false the app is created without the docs UIs.
    enabled: bool
# HTTP server settings (``app`` section of the YAML config).
class App(BaseModel):
    # Port passed to uvicorn.run in main().
    http_port: int
    # Scheme and host[:port] used to build the Location header of POST /book.
    scheme: str
    host: str
# Root of the application configuration; one field per YAML top-level section.
class Config(BaseModel):
    db: Db
    cors: Cors
    log: Log
    swagger: Swagger
    app: App
_CAMEL_BOUNDARY = re.compile("((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))")


def camel_to_snake(s: str) -> str:
    """Convert a camelCase / PascalCase identifier to snake_case.

    An underscore is inserted before an upper-case letter that follows a
    lower-case letter or digit, or that starts a new capitalised word after an
    acronym (``"HTTPPort"`` -> ``"http_port"``); the result is lower-cased.
    """
    with_underscores = _CAMEL_BOUNDARY.sub(r"_\1", s)
    return with_underscores.lower()
def convert_dict_key(d: dict, conv: Callable[[str], str]) -> dict:
    """Return a copy of ``d`` with every key renamed by ``conv``, recursively.

    Nested dicts are converted as well, including dicts found inside (possibly
    nested) lists. All other values are returned unchanged.
    """

    def convert_value(v: Any) -> Any:
        if isinstance(v, dict):
            return convert_dict_key(v, conv)
        if isinstance(v, list):
            return [convert_value(item) for item in v]
        return v

    converted: dict = {}
    for key, value in d.items():
        # Convert the key before the value, as the original comprehension did.
        new_key = conv(key)
        converted[new_key] = convert_value(value)
    return converted
class AppConfig:
    """Global holder of the parsed application configuration.

    ``init`` must be called before ``AppConfig.config`` is read.
    """

    config: Config

    @staticmethod
    def init(config_text: str) -> None:
        """Parse YAML text and store the validated result in ``AppConfig.config``.

        ``!ENV`` tags are resolved by ``pyaml_env.parse_config``, and camelCase
        keys are converted to snake_case to match the pydantic models.

        Raises:
            pydantic.ValidationError: If the data does not match ``Config``.
        """
        config_dict_camel = parse_config(data=config_text)
        config_dict_snake = convert_dict_key(config_dict_camel, camel_to_snake)
        # The parsed dict is deliberately not printed: it contains the DB URL,
        # which may embed credentials.
        AppConfig.config = Config.parse_obj(config_dict_snake)
fastapi_sqlalchemy_v1/config.yml
---
# Runtime configuration loaded by AppConfig.init. Keys are camelCase here and
# are converted to snake_case before validation. `!ENV ${NAME:default}` reads
# environment variable NAME and falls back to `default`.
db:
  url: sqlite:///dev.db.sqlite3
  # Passed to create_engine(echo=...).
  showSql: !ENV ${DB_SHOW_SQL:false}
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  # "JSON" selects the JSON logger; any other value selects plain text.
  format: !ENV ${LOG_FORMAT:PLAIN}
  # Per-logger level overrides.
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # NOTE: the environment variable is APP_SCHEMA (not APP_SCHEME).
  scheme: !ENV ${APP_SCHEMA:http}
  # host[:port] used to build the Location header returned by POST /book.
  host: !ENV ${APP_HOST:localhost:8000}
fastapi_sqlalchemy_v1/book_controller.py
import logging
import traceback
from fastapi import APIRouter, Depends, HTTPException, Response, status
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from fastapi_sqlalchemy_v1.book import Book
from fastapi_sqlalchemy_v1.book_repository import (
BookAddParameter,
BookNotFoundError,
BookRepository,
)
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database
# Routes of this module; main.py includes them into the app.
router = APIRouter()
# Module logger.
logger = logging.getLogger(__name__)
class BookResponseHTTPEntity(BaseModel):
    """Response body describing a stored book."""

    id: int = Field(...)
    name: str = Field(..., example="book")
    lang2: str = Field(..., example="ja")
    # ISO 8601 text produced by ``datetime.isoformat`` (see ``new``).
    created_at: str = Field(...)

    @classmethod
    def new(cls, book: Book) -> "BookResponseHTTPEntity":
        """Build the response entity from a ``Book`` domain object.

        Constructs through ``cls`` so subclasses get an instance of their own
        type.
        """
        return cls(
            id=book.book_id,
            name=book.name,
            lang2=book.lang2,
            created_at=book.created_at.isoformat(),
        )
class BookAddParameterHTTPEntity(BaseModel):
    """Request body for registering a book."""

    # Same limits as BookAddParameter. Invalid input is rejected here by request
    # validation instead of failing later inside add_book, where it would
    # surface as a 500.
    name: str = Field(..., min_length=1, max_length=20)
    # ``min_length`` was misspelled ``min_lengh``, which pydantic accepted as
    # unknown schema extra and never enforced.
    lang2: str = Field(..., min_length=2, max_length=2, example="ja")
def book_repository(session: Session) -> BookRepository:
    """Create a ``BookRepository`` bound to ``session``.

    Kept as a module-level function so tests can patch
    ``fastapi_sqlalchemy_v1.book_controller.book_repository``.
    """
    repository = BookRepository(session)
    return repository
@router.get("/book/{book_id}", response_model=BookResponseHTTPEntity, tags=["Book"])
def get_book(book_id: int, session: Session = Depends(Database.get_session)):
    # Look up one book. A missing book -> 404 "item_not_found"; any other
    # failure is logged with its traceback and reported as a generic 500.
    logger.info("get_book")
    try:
        repository: BookRepository = book_repository(session)
        found: Book = repository.find_book_by_id(book_id)
        return BookResponseHTTPEntity.new(found)
    except BookNotFoundError:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="item_not_found"
        )
    except Exception as exc:
        logger.error(str(exc))
        logger.error(traceback.format_exc())
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="internal server error",
        )
@router.post(
    "/book",
    # Declare 201 so the generated OpenAPI schema matches the Response below
    # (without it FastAPI documents a 200 success response).
    status_code=status.HTTP_201_CREATED,
    tags=["Book"],
    responses={201: {"model": None}},
)
def add_book(
    body: BookAddParameterHTTPEntity, session: Session = Depends(Database.get_session)
):
    """Register a new book.

    Responds 201 with an empty body and a ``Location`` header pointing at the
    new book. Any failure is logged and reported as a generic 500.
    """
    logger.info("add_book")
    app_config = AppConfig.config.app
    try:
        book_add_param = BookAddParameter(name=body.name, lang2=body.lang2)
        book_repo: BookRepository = book_repository(session)
        book_id = book_repo.add_book(book_add_param)
        url = f"{app_config.scheme}://{app_config.host}/book/{book_id}"
        return Response(status_code=status.HTTP_201_CREATED, headers={"location": url})
    except Exception as e:
        logger.error(str(e))
        logger.error(traceback.format_exc())
        raise HTTPException(status_code=500, detail="internal server error")
fastapi_sqlalchemy_v1/const.py
# Import name of the application package; main() builds uvicorn's
# "<package>.main:app" import string from it.
ROOT_PACKAGE: str = "fastapi_sqlalchemy_v1"
fastapi_sqlalchemy_v1/main.py
import logging
import os
import sys
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_sqlalchemy_v1 import book_controller, const, healthcheck_controller, log
from fastapi_sqlalchemy_v1.config import AppConfig
from fastapi_sqlalchemy_v1.database import Database
# Module-level setup: runs on import, so it is done both for `python main.py`
# and when uvicorn imports "fastapi_sqlalchemy_v1.main:app".

# config
config_file_path = os.getenv("CONFIG_FILE_PATH")
if not config_file_path:
    # sys.exit with a message prints it to stderr and exits with status 1.
    sys.exit("CONFIG_FILE_PATH is not specified.")
with open(config_file_path, encoding="utf-8") as f:
    config_text = f.read()
AppConfig.init(config_text)

# logging
if AppConfig.config.log.format == "JSON":
    log.init_json_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)
else:
    log.init_plaintext_logger(AppConfig.config.log.level, AppConfig.config.log.loggers)
logger = logging.getLogger(__name__)

# db
Database.init(db_url=AppConfig.config.db.url, echo=AppConfig.config.db.show_sql)

# app
if AppConfig.config.swagger.enabled:
    app = FastAPI()
else:
    # openapi_url=None is needed too: disabling only docs_url/redoc_url hides
    # the UIs but keeps the raw schema publicly reachable at /openapi.json.
    app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None)
app.add_middleware(
    CORSMiddleware,
    allow_origins=AppConfig.config.cors.allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.include_router(book_controller.router)
app.include_router(healthcheck_controller.router)
def main():
    # Serve the module-level ``app`` on all interfaces at the configured port.
    # The app is passed as an import string because uvicorn requires that
    # form when reload is enabled.
    app_import_path = f"{const.ROOT_PACKAGE}.main:app"
    uvicorn.run(
        app_import_path,
        host="0.0.0.0",
        port=AppConfig.config.app.http_port,
        reload=True,
    )


if __name__ == "__main__":
    main()
テストコード
tests/conftest.py
import os
import pytest
from fastapi_sqlalchemy_v1.database import Database
@pytest.fixture(scope="session")
def test_db():
    """Session-wide fixture that starts the test run from an empty database.

    Deletes the SQLite file named in tests/test_config.yml, recreates the
    schema and yields a session, which is closed at teardown. It relies on
    ``Database.init`` having already run; that happens when the test modules
    import ``fastapi_sqlalchemy_v1.main``.
    """
    print("SetUp")
    try:
        # Best effort: on a first run the file does not exist yet.
        os.remove("test.db.sqlite3")
    except Exception as e:
        print(e)
    Database.create_database()
    session = Database._session_factory()
    yield session
    # Release the session (and its pooled connection) once all tests are done.
    session.close()
    print("TearDown")
tests/test_book_controller.py
from datetime import datetime
from fastapi.testclient import TestClient
from httpx import Response
from fastapi_sqlalchemy_v1.main import app
from tests.conftest import test_db
# Shared test client. With raise_server_exceptions=False, unhandled server
# errors come back as HTTP 500 responses instead of being re-raised here.
client = TestClient(app=app, raise_server_exceptions=False)
def test_add_book(test_db):
    base_url = "http://localhost:8000"

    # when: register new data
    response: Response = client.post("/book", json={"name": "NAME", "lang2": "ja"})

    # then: 201, empty body, absolute URL of the new book in Location
    assert response.status_code == 201
    assert len(response.content) == 0
    assert response.headers["content-length"] == "0"
    assert response.headers["location"] == "http://localhost:8000/book/1"
    relative_location = response.headers["location"].removeprefix(base_url)
    print(relative_location)

    # when: fetch the registered data through the returned location
    response = client.get(relative_location)

    # then
    assert response.status_code == 200
    body = response.json()
    assert body["name"] == "NAME"
    assert body["lang2"] == "ja"
    # created_at comes from the DB server default; presumably stored in UTC,
    # hence the comparison with utcnow().
    created_at = datetime.fromisoformat(body["created_at"])
    assert abs((datetime.utcnow() - created_at).total_seconds()) < 3
def test_add_book_internal_server_error(test_db, mocker):
    expected_body = {"detail": "internal server error"}

    # given: creating the repository fails with an unexpected error
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository", side_effect=Exception()
    )

    # when: register new data
    response: Response = client.post("/book", json={"name": "NAME", "lang2": "ja"})

    # then: the failure is reported as a generic 500
    assert response.status_code == 500
    assert response.json() == expected_body
def test_find_book_that_doesnt_exist(test_db):
    missing_book_id = 999

    # when: request an id that was never registered
    response = client.get(f"/book/{missing_book_id}")

    # then
    assert response.status_code == 404
    assert response.json() == {"detail": "item_not_found"}
def test_find_book_internal_server_error(test_db, mocker):
    # given: creating the repository fails with an unexpected error
    mocker.patch(
        "fastapi_sqlalchemy_v1.book_controller.book_repository", side_effect=Exception()
    )

    # when
    book_id = 1
    response = client.get(f"/book/{book_id}")

    # then: the failure is reported as a generic 500
    assert response.status_code == 500
    assert response.json() == {"detail": "internal server error"}
tests/test_healthcheck_controller.py
from fastapi.testclient import TestClient
from fastapi_sqlalchemy_v1.main import app
# Test client; server-side exceptions are returned as 500 responses rather
# than re-raised in the test.
client = TestClient(app=app, raise_server_exceptions=False)
def test_find_book_that_doesnt_exist(test_db):
    # NOTE(review): despite its name, this test checks the /health endpoint.
    health_path = "/health"

    # when
    response = client.get(health_path)

    # then: 200 with an empty body
    assert response.status_code == 200
    assert response.headers["Content-Length"] == "0"
tests/test_config.yml
---
# Configuration used by `make test` (CONFIG_FILE_PATH=tests/test_config.yml).
# Same structure as fastapi_sqlalchemy_v1/config.yml.
db:
  # The test_db fixture deletes and recreates this SQLite file.
  url: sqlite:///test.db.sqlite3
  showSql: true
cors:
  allowOrigins:
    - !ENV ${CORS_ALLOW_ORIGIN:http://localhost:5173}
log:
  level: !ENV ${LOG_LEVEL:INFO}
  format: !ENV ${LOG_FORMAT:PLAIN}
  loggers:
    - name: uvicorn.access
      level: INFO
swagger:
  enabled: !ENV ${SWAGGER_ENABLED:true}
app:
  httpPort: !ENV ${APP_HTTP_PORT:8000}
  # Fixed values: test_add_book asserts Location == http://localhost:8000/book/1.
  scheme: http
  host: localhost:8000
以上です。