From 3693ade68ff94b8ce2398fa2eb31c981b78b72f6 Mon Sep 17 00:00:00 2001 From: Khalim Conn-Kowlessar Date: Fri, 16 Jun 2023 09:42:10 +0100 Subject: [PATCH] added utils test --- model_data/tests/test_utils.py | 49 ++++++++++++++++++++++++++++++++++ model_data/utils.py | 6 ++++- 2 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 model_data/tests/test_utils.py diff --git a/model_data/tests/test_utils.py b/model_data/tests/test_utils.py new file mode 100644 index 00000000..ea8d0fd0 --- /dev/null +++ b/model_data/tests/test_utils.py @@ -0,0 +1,49 @@ +import logging +from io import StringIO +from unittest.mock import patch +from model_data.utils import setup_logger + + +class TestLogger: + def test_setup_logger_default(self): + log_stream = StringIO() + handler = logging.StreamHandler(log_stream) + logger = setup_logger() + logger.addHandler(handler) + + # log something + logger.info("Hello World!") + + log_stream.seek(0) + # assert that log was written + assert log_stream.read() == "Hello World!\n" + # remove the handler after use + logger.removeHandler(handler) + + @patch('logging.FileHandler') + def test_setup_logger_file(self, mock_file_handler): + # setup the logger + logger = setup_logger(log_file='test.log', overwrite_handler=True) + + # assert FileHandler was called correctly + mock_file_handler.assert_called_once_with('test.log') + + # clean up after use + for handler in logger.handlers[:]: + handler.close() + logger.removeHandler(handler) + + def test_setup_logger_loglevel(self): + log_stream = StringIO() + handler = logging.StreamHandler(log_stream) + logger = setup_logger(level=logging.DEBUG) + logger.addHandler(handler) + + # log something + logger.debug("Hello World!") + + log_stream.seek(0) + # assert that log was written + assert log_stream.read() == "Hello World!\n" + # remove the handler after use + logger.removeHandler(handler) diff --git a/model_data/utils.py b/model_data/utils.py index 9fe04c89..d643f36a 100644 --- a/model_data/utils.py +++ b/model_data/utils.py @@ -1,11 +1,15 @@ import logging -def setup_logger(log_file=None, level=logging.INFO): +def setup_logger(log_file=None, level=logging.INFO, overwrite_handler=False): # Create a logger and set the logging level logger = logging.getLogger() logger.setLevel(level) + # if logger already has handlers, just return it + if logger.hasHandlers() and not overwrite_handler: + return logger + # Define the log message format log_format = "%(asctime)s [%(levelname)s] %(message)s" date_format = "%Y-%m-%d %H:%M:%S"