Patching Guide
Overview
This guide shows how to use @patch so that your tests can still run on GitHub after you have added a database.
Setup
Here is my example DataSource:
import records
import ProductionCode.psql_config as config

class DataSource:
    def __init__(self):
        self.database_url = f"postgresql://{config.USER}:{config.PASSWORD}@{config.HOST}:5432/{config.DATABASE}"
        self.database = records.Database(self.database_url)

    def get_pokemon_by_name(self, name):
        query = "SELECT * FROM pokemon WHERE name = :name"
        rows = self.database.query(query, name=name)
        if rows:
            return rows[0].export('csv')
        return None

    def get_pokemon_by_name_dictionary(self, name):
        #an example if you just want to return the Record object dictionary
        query = "SELECT * FROM pokemon WHERE name = :name"
        rows = self.database.query(query, name=name)
        return rows

    def get_pokemon_by_stat(self, stat, count):
        '''
        Returns the top number of Pokemon by the given stat
        An example of returning multiple rows with an export
        :param stat: The stat to use
        :param count: The number of Pokemon to display
        :return: A list of Pokemon
        '''
        #sort the data by the given stat
        rows = self.database.query(f"SELECT * FROM pokemon ORDER BY {stat} DESC LIMIT {count}")
        return rows.export('csv')
Testing get_pokemon_by_name
My method doesn’t do a whole lot, but if I wanted to test the Python code (not the query), here is how.
The unittest.mock.patch tool allows you to swap out the real records.Database with a Mock object. This lets you “fake” the database response and verify that your logic — like calling .export('csv') — is working correctly without needing to connect to the real server.
The Strategy: Mocking the Chain
In my get_pokemon_by_name method, there is a chain of calls that I need to simulate:
- records.Database() creates the database object.
- database.query() returns a list-like object of rows.
- rows[0] is a row object.
- rows[0].export('csv') returns the final string.
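The chain above can be tried out on its own, with no database and no DataSource. This is a minimal sketch using only unittest.mock; the URL string and the variable names are placeholders made up for illustration.

```python
from unittest.mock import MagicMock

# Stand-in for the patched records.Database class
mock_db_class = MagicMock()

# A fake row whose export() always returns the same CSV string
mock_row = MagicMock()
mock_row.export.return_value = "1,Pikachu,Electric"

# Calling the class returns mock_db_class.return_value (the "instance"),
# and that instance's query() returns a list holding our fake row
mock_db_class.return_value.query.return_value = [mock_row]

# Walk the same chain that get_pokemon_by_name walks
database = mock_db_class("postgresql://placeholder")
rows = database.query("SELECT * FROM pokemon WHERE name = :name", name="Pikachu")
result = rows[0].export('csv')

print(result)  # 1,Pikachu,Electric
```

Every step of the chain is a mock that was configured ahead of time, which is the same setup the tests below build inside each test method.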
The Test Implementation
Assuming my code is in a file named datasource.py in ProductionCode, here is how I write the test:
import unittest
from unittest.mock import patch, MagicMock
from ProductionCode.datasource import DataSource

class TestDataSource(unittest.TestCase):

    @patch('ProductionCode.datasource.records.Database')
    def test_get_pokemon_by_name_returns_csv(self, mock_db_class):
        # The combo of @patch and the parameter mock_db_class automatically
        # creates a MagicMock of records.Database and saves it to mock_db_class

        # 1. Setup the Mocks
        # mock_db_class is the already mocked 'records.Database' class (really its constructor)
        # mock_db_class.return_value is a new mock object, because mock object constructors make
        # more mock objects
        # mock_db_instance is the object that will be returned when 'records.Database()' is called
        mock_db_instance = mock_db_class.return_value

        # Create a mock for the individual row, MagicMock makes an object with all the normal built-in methods
        # and lets you add additional things to it
        mock_row = MagicMock()
        # When export is called on the mock_row, regardless of parameter, this string is returned
        # this works because all functions/methods are actually objects in Python
        mock_row.export.return_value = "1,Pikachu,Electric"

        # Mock the .query() method to return a list containing our mock_row
        mock_db_instance.query.return_value = [mock_row]

        # 2. Initialize a fresh DataSource for this test
        ds = DataSource()
        # Run the code we are testing
        result = ds.get_pokemon_by_name("Pikachu")

        # 3. Assertions
        # Verify the database was queried with the correct parameters
        mock_db_instance.query.assert_called_once_with(
            "SELECT * FROM pokemon WHERE name = :name",
            name="Pikachu"
        )
        # Verify the row's export method was called
        mock_row.export.assert_called_with('csv')
        # Verify the final output
        self.assertEqual(result, "1,Pikachu,Electric")

    @patch('ProductionCode.datasource.records.Database')
    def test_get_pokemon_by_name_dictionary(self, mock_db_class):
        # Setup the mock database instance
        mock_db_instance = mock_db_class.return_value
        # Mock the query result (A Record is basically a dictionary, so we can just make a dictionary)
        mock_db_instance.query.return_value = {'number': 1, 'name': 'Bulbasaur', 'type_1': 'Fake'}

        # Initialize a fresh DataSource for this test
        ds = DataSource()
        # Act
        result = ds.get_pokemon_by_name_dictionary("Bulbasaur")

        # Assert
        self.assertEqual(result['number'], 1)
        self.assertEqual(result['name'], 'Bulbasaur')
        self.assertEqual(result['type_1'], 'Fake')
        # Verify the query was called with the correct parameter
        mock_db_instance.query.assert_called_once_with(
            "SELECT * FROM pokemon WHERE name = :name",
            name="Bulbasaur"
        )

    @patch('ProductionCode.datasource.records.Database')
    def test_get_pokemon_by_name_not_found(self, mock_db_class):
        # Setup query to return an empty list
        mock_db_class.return_value.query.return_value = []
        ds = DataSource()
        result = ds.get_pokemon_by_name("MissingNo")
        self.assertIsNone(result)

    @patch('ProductionCode.datasource.records.Database')
    def test_get_pokemon_by_stat(self, mock_db_class):
        # Setup the mock database instance
        mock_db_instance = mock_db_class.return_value
        # Mock the query result, which needs to be an object if you are going to
        # call its export
        records_object = MagicMock()
        # if you want to get multiple rows back, include them all with new line separators
        records_object.export.return_value = "1,Bulbasaur,Fake,62\n2,Ivysaur,Fake,49"
        # set your new mock object to be what is returned when your database is queried
        mock_db_instance.query.return_value = records_object

        # Run your function
        ds = DataSource()
        result = ds.get_pokemon_by_stat("attack", 2)

        # Check to make sure the right things happened in your Python
        expected_csv = "1,Bulbasaur,Fake,62\n2,Ivysaur,Fake,49"
        self.assertEqual(result, expected_csv)
        mock_db_instance.query.assert_called_once_with(
            "SELECT * FROM pokemon ORDER BY attack DESC LIMIT 2"
        )
Key Takeaways
- Patch where the object is used: Notice I patched ProductionCode.datasource.records.Database. You want to intercept the name the way your module looks it up, not wherever it was originally defined. (Because datasource.py does import records, this target ends up pointing at the same attribute as records.Database. The distinction matters more if your module does from records import Database, in which case the target would be ProductionCode.datasource.Database.)
- Use .return_value frequently: Since DataSource calls records.Database(), the patch gives us the class. To control the instance created inside __init__, we use mock_db_class.return_value. The line mock_db_instance = mock_db_class.return_value is just giving us a more convenient name for the object returned by calling records.Database(self.database_url).
- MagicMock for Rows: Because the records library returns row objects with their own methods (like .export()), we create a MagicMock() for the row and define its behavior separately.
- Pretend it’s a dictionary: If you aren’t using export or any other Record object methods, you can just use a Python dictionary for your tests.
- You can adapt the above to work for your project without needing to change very much.
- Note, my tests have a lot of duplication for clarity since I’m assuming you’ll only be looking at one of them. You can definitely cut down in places!
Testing Flask routes
I have the following Flask file, using my DataSource from above:
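from flask import Flask, request, render_template
from ProductionCode.datasource import DataSource

app = Flask(__name__)
ds = DataSource()

@app.route('/')
def index():
    return render_template('index.html')

# The URL variable name must match the view function's parameter name
@app.route('/pokemon/<poke_name>')
def display(poke_name):
    result = ds.get_pokemon_by_name(poke_name)
    # A Flask view can't return None, so send a message when nothing is found
    if result is None:
        return "No pokemon!"
    return result

if __name__ == '__main__':
    app.run()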
The Test Implementation
In my test file, I will patch the ds object specifically inside the app module. First, I need to avoid the DataSource trying to contact the database:
import unittest
from unittest.mock import patch
# Patch out the database right away so the real records.Database constructor never runs
db_patcher = patch('ProductionCode.datasource.records.Database')
mock_db_class = db_patcher.start()
# Safely import app without worrying about DataSource contacting database
from app import app
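The patcher above is started at import time and never stopped, which is fine for a single test file. As a rough sketch of what start() and stop() actually do, here is the same pattern on a made-up stand-in (fake_records and RealDatabase are hypothetical names, not part of the records library):

```python
import types
from unittest.mock import patch


class RealDatabase:
    # Hypothetical stand-in for a class that connects in its constructor
    def __init__(self, url):
        raise RuntimeError("would connect to a real database")


# Hypothetical stand-in for the records module
fake_records = types.SimpleNamespace(Database=RealDatabase)

db_patcher = patch.object(fake_records, "Database")
mock_db_class = db_patcher.start()

# While the patch is active, "constructing" a database is harmless
database = fake_records.Database("postgresql://placeholder")
is_patched = fake_records.Database is mock_db_class

# stop() puts the original class back
db_patcher.stop()
is_restored = fake_records.Database is RealDatabase
```

If other test files in the same run need the real class back, one option is a module-level tearDownModule function that calls db_patcher.stop().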
Then I can test with the following:
class TestFlaskApp(unittest.TestCase):

    def setUp(self):
        # Creates a test client so we can make requests without running a server
        self.client = app.test_client()
        self.client.testing = True

    @patch('app.ds.get_pokemon_by_name')
    def test_display_route_success(self, mock_get_pokemon):
        # 1. Setup the Mock response
        # We simulate what the Datasource would return (the CSV string)
        mock_get_pokemon.return_value = "25,Pikachu,Electric"

        # 2. Simulate a GET request to the route
        # (no trailing slash, since the route is defined without one)
        response = self.client.get('/pokemon/Pikachu')

        # 3. Assertions
        self.assertEqual(response.status_code, 200)
        self.assertEqual(response.data.decode('utf-8'), "25,Pikachu,Electric")
        # Verify the mock was called with the right argument from the URL
        mock_get_pokemon.assert_called_once_with("Pikachu")

    @patch('app.ds.get_pokemon_by_name')
    def test_display_route_not_found(self, mock_get_pokemon):
        # Setup the mock to return None (simulating a missing database record)
        mock_get_pokemon.return_value = None
        response = self.client.get('/pokemon/MissingNo')
        self.assertEqual(response.status_code, 200)
        self.assertIn("No pokemon!", response.data.decode('utf-8'))
Key Concepts for Flask Patching
- Preventing DataSource constructor: In my current setup, I have a global DataSource, and DataSource creates a connection within its constructor, so I need to mock that class before importing app. I could avoid that with refactoring to not immediately connect to the database in the constructor or with a “factory” design pattern, which we’ll discuss later in the term.
- app.test_client(): This acts as a “browser in a box.” It lets you trigger routes and inspect the response (status codes, headers, and body) without actually starting a web server on a port.
- Patching the Instance: In the previous section, I patched the class (records.Database). Here, I patched the method of an already existing instance (app.ds.get_pokemon_by_name). This is often easier because you aren’t messing with the constructor logic and it is ideal to patch as close to the function that you are testing as possible.
- Response Decoding: Flask response.data returns bytes. To compare it to a string, you’ll usually need to call .decode('utf-8').
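Patching a method on an existing instance can also be written with patch.object, which takes the object itself rather than a dotted string. This is a minimal sketch, assuming a made-up Lookup class and a plain display function standing in for the DataSource and the Flask view (neither is taken from the app above):

```python
from unittest.mock import patch


class Lookup:
    # Hypothetical stand-in for DataSource
    def get_pokemon_by_name(self, name):
        raise RuntimeError("would query a real database")


ds = Lookup()  # module-level instance, like ds in the Flask file


def display(poke_name):
    # Plain function standing in for the Flask view
    result = ds.get_pokemon_by_name(poke_name)
    return result if result is not None else "No pokemon!"


# Only this one method on this one instance is replaced
with patch.object(ds, "get_pokemon_by_name") as mock_get_pokemon:
    mock_get_pokemon.return_value = "25,Pikachu,Electric"
    found = display("Pikachu")
    mock_get_pokemon.assert_called_once_with("Pikachu")

    # Simulate a missing record
    mock_get_pokemon.return_value = None
    missing = display("MissingNo")
```

Once the with block ends, the real method is back in place, so nothing leaks into other tests.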
For More
There is a lot more that you can do with Mock and patch, and the best place to learn more is the Python documentation:
- patch start and stop - if you want to get rid of repetitive patching of the database
- Quick Guide - for more on how Mock and MagicMock actually work