I'm working on learning how to unit test properly. Given this function...
def get_user_details(req_user_id):
    users = sa.Table('users', db.metadata, autoload=True)
    s = sa.select([users.c.username,
                   users.c.favorite_color,
                   users.c.favorite_book]).select_from(users)
    s = s.where(users.c.user_id == req_user_id)
    result = db.connection.execute(s).fetchone()
    return dict(result)
...what is the proper way to unit test it?
Here's where I am now...
- From what I've read, testing the "construct" of the query is unnecessary as that's part of the already well-tested SQLAlchemy library. So I don't need to test the raw SQL generated, right? But should I test the parameters passed and if so, how?
- I've read about mocking the response that comes from db.connection.execute, but how is that really testing anything? Ultimately, I want to make sure the function is generating the proper SQL and getting the right database result.
Any advice/guidance is much appreciated. Thank you!
Following from this comment:
what you need to test is if the statements in your code produce expected results. – Shod
and your code, here's what I'm aiming to answer:
How can I test a method that is dynamically generating SQLAlchemy queries?
That is the issue I was having: I wanted to make sure the generated queries were indeed correct, not because I doubt SQLAlchemy, but because of the logic that puts them together.
the code we will test
def add_filters(query, target_table, filter_type: str, filter_value: str):
    if filter_type == "favourite_book":
        query = query.filter(target_table.c.favourite_book == filter_value)
    elif filter_type == "favourite_color":
        query = query.filter(target_table.c.favourite_color == filter_value)
    return query
So I want to test that the favourite_book filter is indeed correctly added to the query.
In order to do this, we will create a temporary SQLite database with a table containing data, and run the queries against it. Finally, we test the result of the query. NOTE: for comprehensive testing, you'll want the table to contain both rows that should match the filter and rows that should not.
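To illustrate why the note about good and bad data matters, here is a minimal sketch that uses only the standard-library sqlite3 module rather than the SQLAlchemy setup from this answer. The near-miss row and the deliberately buggy books_prefix helper are hypothetical, made up for this illustration. With only "clean" rows both helpers would return the same thing, so a test could not tell them apart.

```python
import sqlite3

# Test data. The last row is a hypothetical "near miss": it shares a prefix
# with the value we filter on, so a sloppy filter would wrongly match it.
ROWS = [
    ("john", "blue", "Harry Potter"),
    ("jane", "red", "The Power of Now"),
    ("bob", "green", "Extreme Ownership"),
    ("amy", "green", "Extreme Ownership (2nd edition)"),
]


def make_db():
    """Create a throwaway in-memory database holding the rows above."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users ("
        "id INTEGER PRIMARY KEY, username TEXT, "
        "favourite_color TEXT, favourite_book TEXT)"
    )
    conn.executemany(
        "INSERT INTO users (username, favourite_color, favourite_book) "
        "VALUES (?, ?, ?)",
        ROWS,
    )
    return conn


def books_exact(conn, value):
    """Intended behaviour: exact match on favourite_book."""
    cur = conn.execute(
        "SELECT favourite_book FROM users WHERE favourite_book = ? ORDER BY id",
        (value,),
    )
    return [row[0] for row in cur.fetchall()]


def books_prefix(conn, value):
    """Deliberately buggy variant (hypothetical): prefix match instead of equality."""
    cur = conn.execute(
        "SELECT favourite_book FROM users WHERE favourite_book LIKE ? ORDER BY id",
        (value + "%",),
    )
    return [row[0] for row in cur.fetchall()]


conn = make_db()
exact = books_exact(conn, "Extreme Ownership")
loose = books_prefix(conn, "Extreme Ownership")
conn.close()
```

Only the near-miss row makes exact and loose differ, which is the kind of row the note above is asking for.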
setting up the testing database
import pytest
from sqlalchemy.orm import Session
from sqlalchemy import create_engine, MetaData, Table, Column, Integer, String
import pandas as pd


@pytest.fixture
def engine():
    my_engine = create_engine(
        "sqlite:///file:test_db?mode=memory&cache=shared&uri=true", echo=True
    )
    return my_engine


@pytest.fixture
def target_table(engine):
    meta = MetaData()
    table_name = "users"
    users = Table(
        table_name,
        meta,
        Column("id", Integer, primary_key=True),
        Column("username", String),
        Column("favourite_color", String),
        Column("favourite_book", String),
    )
    meta.create_all(engine)
    # you can choose to skip the whole table declaration as df.to_sql will create
    # the table for you if it doesn't exist
    records = [
        {"username": "john", "favourite_color": "blue", "favourite_book": "Harry Potter"},
        {"username": "jane", "favourite_color": "red", "favourite_book": "The Power of Now"},
        {"username": "bob", "favourite_color": "green", "favourite_book": "Extreme Ownership"},
    ]
    df = pd.DataFrame(records)
    df.to_sql(table_name, engine, if_exists="append", index=False)
    return users
And finally the actual test
def test_query(engine, target_table):
    with Session(engine) as session:
        query = session.query(target_table)
        query = add_filters(query, target_table, "favourite_book", "Extreme Ownership")
        df = pd.read_sql_query(query.statement, session.bind)
        assert df["favourite_book"].unique().tolist() == ["Extreme Ownership"]
As you can see, this test is not very comprehensive, as it only tests one case.
However, we can use pytest.mark.parametrize to extend that. (see last reference)
@pytest.mark.parametrize(
    "filter_type,filter_value,expected_result",
    [
        ("favourite_book", "Extreme Ownership", True),
        ("favourite_book", "Extreme", False),
        ("favourite_color", "blue", True),
        ("favourite_color", "purple", False),
    ],
)
def test_query(engine, target_table, filter_type, filter_value, expected_result):
    with Session(engine) as session:
        query = session.query(target_table)
        query = add_filters(query, target_table, filter_type, filter_value)
        df = pd.read_sql_query(query.statement, session.bind)
        assert (df[filter_type].unique().tolist() == [filter_value]) == expected_result
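A lighter-weight complement, which also touches on the question of checking the parameters passed, is to compile the statement and inspect the SQL text and bound values without touching a database at all. This is only a sketch, assuming SQLAlchemy 1.4 or later (for select(users) and Select.filter). The add_filters function is copied from above and the table is redeclared so the snippet stands alone. It checks how the query is put together, not what a real database returns.

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, select

# Same table shape as in the fixture above, declared without any engine.
users = Table(
    "users",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("username", String),
    Column("favourite_color", String),
    Column("favourite_book", String),
)


def add_filters(query, target_table, filter_type: str, filter_value: str):
    # Copy of the function under test, so the snippet is self-contained.
    if filter_type == "favourite_book":
        query = query.filter(target_table.c.favourite_book == filter_value)
    elif filter_type == "favourite_color":
        query = query.filter(target_table.c.favourite_color == filter_value)
    return query


# Build a filtered statement and compile it with the default dialect.
stmt = add_filters(select(users), users, "favourite_book", "Extreme Ownership")
compiled = stmt.compile()
sql_text = str(compiled)                        # SQL with named placeholders
bound_values = list(compiled.params.values())   # values bound to the placeholders

# An unknown filter type should leave the statement without a WHERE clause.
unfiltered_sql = str(add_filters(select(users), users, "unknown", "x").compile())
```

In a test you would assert on sql_text and bound_values. Avoid comparing the whole SQL string, since its exact formatting and placeholder names are details of the SQLAlchemy compiler.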
some references:
- https://www.tutorialspoint.com/pytest/pytest_fixtures.htm
- https://smirnov-am.github.io/pytest-testing_database/
- Python - How to connect SQLAlchemy to existing database in memory
- Mocking database calls in python using pytest-mock