Skip to main content

Mocking in python3, mocking boto3 library, aws

· 4 min read
Hreniuc Cristian-Alexandru

An example of how you can mock things in Python.

Let's assume we have the following folder structure:

- project
  - src
    - test
    - utils
      - aws.SSMClient.AWSFactory
    - and so on, check below in the code

To run the unit tests:

cd src

python3 -m unittest discover -p "*test*" -v

If you add a new subfolder in test, you also need to add an empty file called __init__.py to it, so that the command above discovers it.

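When adding unit tests, you can re-enable logging for a specific test case (setUp turns it off with logging.disable(logging.CRITICAL)) by adding the following line at the beginning of the test case:

logging.disable(logging.DEBUG)

This re-enables every level above DEBUG. Use logging.disable(logging.NOTSET) if you also want to see DEBUG messages.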

Also, in all unit tests we inject fake responses from AWS using mock, and we enforce that methods are called with a specific set of params. To see the mock calls made on the boto3 (AWS) instances, you can temporarily add the following prints:

print(f'\n{self.mock_ssm.mock_calls}\n')
print(f'\n{self.mock_ec2.mock_calls}\n')
print(f'\n{self.mock_sqs.mock_calls}\n')

This prints all calls made on the AWS instances, so you can verify them and add matching mock expectations to your tests.

In the file below you will notice that we import fake data and Helpers. These are just canned values (strings or dicts) that we want the mocks to return, or the expected call params of a mock method, e.g.:

# Helper call for mock create param from boto3: the exact keyword arguments
# expected on the mocked SSM client's put_parameter when the param is created.
# `call` must be imported explicitly; without it this snippet raises NameError.
from unittest.mock import call

create_param_call = call(
    Name='/REGION_SETUP/cristi/unique-id',
    Value='cristi/unique-id',
    Description='cristi/unique-id',
    Type='String',
    Tier='Standard',
    DataType='text',
)

# Fake data: fake_ssm
# Canned SSM get_parameter response whose stored Value reports a 500 error.
# LastModifiedDate is kept as the printed repr string, not a datetime object.
internal_error_result_param = dict(
    Parameter=dict(
        Name="/REGION_SETUP/cristi/unique-id",
        Type="String",
        Value="{'statusCode': 500, 'body': 'Internal error'}",
        Version=1,
        LastModifiedDate="datetime.datetime(2022, 11, 10, 15, 20, 16, 303000, tzinfo=tzlocal())",
        ARN="arn:aws:ssm:eu-west-1:968428508743:parameter/REGION_SETUP/cristi/unique-id67635014db534a10ac772adc42b8f114",
        DataType="text",
    ),
    ResponseMetadata=dict(
        RequestId="252b3b53-eefa-4d0d-aa44-be953bac3b45",
        HTTPStatusCode=200,
        # Header names contain dashes, so they stay as a dict literal.
        HTTPHeaders={
            "server": "Server",
            "date": "Thu, 10 Nov 2022 15:20:16 GMT",
            "content-type": "application/x-amz-json-1.1",
            "content-length": "464",
            "connection": "keep-alive",
            "x-amzn-requestid": "252b3b53-eefa-4d0d-aa44-be953bac3b45",
        },
        RetryAttempts=0,
    ),
)

Example test file:

import logging
import unittest
from unittest.mock import MagicMock, patch
import test.fake_data.events.trigger.event as fake_event
import test.fake_data.ssm as fake_ssm
from utils.trigger import Trigger
import test.trigger.Helper as Helper


# python3 -m unittest discover -p "*test*" -v


class TestTrigger(unittest.TestCase):
    """Unit tests for ``Trigger.run`` with SSM, SQS, uuid and the timeout patched.

    ``setUp`` patches the AWS factory methods so the code under test receives
    ``self.mock_ssm`` / ``self.mock_sqs``. Each test scripts the fake SSM
    responses and then asserts on the calls recorded by those mocks.
    Temporarily print ``self.mock_ssm.mock_calls`` or
    ``self.mock_sqs.mock_calls`` to see the exact calls while writing a test.
    """

    def setUp(self):
        # Silence every log record during the test. The cleanup lifts the
        # override afterwards; without it logging stays disabled for every
        # test that runs after this one in the same process.
        logging.disable(logging.CRITICAL)
        self.addCleanup(logging.disable, logging.NOTSET)

        self.mock_ssm = MagicMock(name="mock_ssm")
        self.mock_sqs = MagicMock(name="mock_sqs")

        self.aws_ssm_patcher = patch(
            'utils.aws.SSMClient.AWSFactory.get_ssm_instance',
            return_value=self.mock_ssm)
        self.aws_sqs_patcher = patch(
            'utils.aws.SQSClient.AWSFactory.get_sqs_instance',
            return_value=self.mock_sqs)
        self.uuid_patcher = patch(
            'utils.trigger.uuid.uuid4', return_value=Helper.generated_uuid)
        # Lower TRIGGER_MAX_TIME_EXECUTION_DURATION so the timeout case
        # (test_run_timeout) finishes quickly.
        self.timeout_patcher = patch(
            'utils.trigger.Config.TRIGGER_MAX_TIME_EXECUTION_DURATION', 0.08)

        # Register each stop() right after its start(). Cleanups run even if
        # setUp fails part-way (e.g. inside Trigger()), whereas tearDown would
        # be skipped and the patches would leak into the remaining tests.
        self.mock_SSM_AWSFactory = self.aws_ssm_patcher.start()
        self.addCleanup(self.aws_ssm_patcher.stop)
        self.mock_SQS_AWSFactory = self.aws_sqs_patcher.start()
        self.addCleanup(self.aws_sqs_patcher.stop)
        self.uuid_patcher.start()
        self.addCleanup(self.uuid_patcher.stop)
        self.timeout_patcher.start()
        self.addCleanup(self.timeout_patcher.stop)

        self.trigger = Trigger()

    def _assert_single_sqs_message_sent(self):
        """Assert exactly one SQS message was sent and 3 SQS calls were made."""
        self.mock_sqs.get_queue_url.assert_called()
        self.mock_sqs.send_message.assert_has_calls(
            [Helper.sqs_send_message_call])
        self.assertEqual(self.mock_sqs.send_message.call_count, 1)
        # 2 * call.get_queue_url + send_message
        self.assertEqual(len(self.mock_sqs.mock_calls), 3)

    def test_run_timeout(self):
        self.mock_ssm.put_parameter.return_value = "response"
        # Every get_parameter returns fake_ssm.running_param, so the patched
        # 0.08 execution budget runs out and the trigger reports a 408.
        self.mock_ssm.get_parameter.return_value = fake_ssm.running_param

        self.assertEqual(
            self.trigger.run(fake_event.delete_event),
            {'statusCode': 408, 'body': 'Request Timeout'})

        self._assert_single_sqs_message_sent()

        # Get param
        self.mock_ssm.get_parameter.assert_called_with(
            Name=Helper.param_name, WithDecryption=True)
        # Create param and rename to failed
        self.mock_ssm.put_parameter.assert_has_calls(
            [Helper.create_param_call, Helper.rename_to_failed_call])
        self.assertEqual(self.mock_ssm.put_parameter.call_count, 2)
        # Delete param
        self.mock_ssm.delete_parameter.assert_called_once_with(
            Name=Helper.param_name)

    def test_run_server_error_in_trigger(self):
        self.mock_ssm.put_parameter.return_value = "response"
        # A plain string instead of a get_parameter response dict; this will
        # trigger an exception inside the trigger.
        self.mock_ssm.get_parameter.return_value = "param"

        self.assertEqual(
            self.trigger.run(fake_event.delete_event),
            {'statusCode': 500, 'body': 'Internal error'})

        self._assert_single_sqs_message_sent()

        # Get param
        self.mock_ssm.get_parameter.assert_called_with(
            Name=Helper.param_name, WithDecryption=True)

        self.mock_ssm.put_parameter.assert_has_calls(
            [Helper.create_param_call])
        self.assertEqual(self.mock_ssm.put_parameter.call_count, 1)

    def test_run_server_error_in_executor(self):
        self.mock_ssm.put_parameter.return_value = "response"
        # The stored param value reports a 500 'Internal error' result.
        self.mock_ssm.get_parameter.return_value = fake_ssm.internal_error_result_param

        self.assertEqual(
            self.trigger.run(fake_event.delete_event),
            {'statusCode': 500, 'body': 'Internal error'})

        self._assert_single_sqs_message_sent()

        # Get param
        self.mock_ssm.get_parameter.assert_called_with(
            Name=Helper.param_name, WithDecryption=True)
        self.assertEqual(self.mock_ssm.get_parameter.call_count, 3)

        # Create param and rename to failed (server error)
        self.mock_ssm.put_parameter.assert_has_calls(
            [Helper.create_param_call, Helper.rename_to_failed_with_server_error_call])
        self.assertEqual(self.mock_ssm.put_parameter.call_count, 2)
        # Delete param
        self.mock_ssm.delete_parameter.assert_called_once_with(
            Name=Helper.param_name)
        # Total ssm calls
        self.assertEqual(len(self.mock_ssm.mock_calls), 6)

    def test_run_get_param_returns_another_exception(self):
        self.mock_ssm.put_parameter.return_value = "response"
        # First get_parameter returns the event; the second yields
        # Helper.param_exception (raised by the mock if it is an exception).
        self.mock_ssm.get_parameter.side_effect = [
            fake_event.delete_event, Helper.param_exception]

        self.assertEqual(
            self.trigger.run(fake_event.delete_event),
            {'statusCode': 500, 'body': 'Internal error'})

        self._assert_single_sqs_message_sent()

        # Get param
        self.mock_ssm.get_parameter.assert_called_with(
            Name=Helper.param_name, WithDecryption=True)
        self.assertEqual(self.mock_ssm.get_parameter.call_count, 2)

        self.mock_ssm.put_parameter.assert_has_calls(
            [Helper.create_param_call])
        self.assertEqual(self.mock_ssm.put_parameter.call_count, 1)
        # Total ssm calls
        self.assertEqual(len(self.mock_ssm.mock_calls), 3)

    def test_run_success(self):
        self.mock_ssm.put_parameter.return_value = "response"
        # First get_parameter returns the event; the second yields
        # Helper.param_not_found_exception, which the trigger treats as success.
        self.mock_ssm.get_parameter.side_effect = [
            fake_event.delete_event, Helper.param_not_found_exception]

        self.assertEqual(
            self.trigger.run(fake_event.delete_event),
            {'statusCode': 200})

        self._assert_single_sqs_message_sent()

        # Get param
        self.mock_ssm.get_parameter.assert_called_with(
            Name=Helper.param_name, WithDecryption=True)
        self.assertEqual(self.mock_ssm.get_parameter.call_count, 2)

        self.mock_ssm.put_parameter.assert_has_calls(
            [Helper.create_param_call])
        self.assertEqual(self.mock_ssm.put_parameter.call_count, 1)
        # Total ssm calls
        self.assertEqual(len(self.mock_ssm.mock_calls), 3)

if __name__ == '__main__':
    # Allows running this file directly; discovery with
    # `python3 -m unittest discover -p "*test*" -v` works as well.
    unittest.main()