From e157d0ce97681b28e262f5bf8f1f083a32a7bd6c Mon Sep 17 00:00:00 2001 From: Jun-te Kim Date: Mon, 3 Mar 2025 12:52:23 +0000 Subject: [PATCH] added kahlim sharepoint code --- etl/src/etl/utils/sharepoint/__init__.py | 0 etl/src/etl/utils/sharepoint/sharepoint.py | 272 +++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 etl/src/etl/utils/sharepoint/__init__.py create mode 100644 etl/src/etl/utils/sharepoint/sharepoint.py diff --git a/etl/src/etl/utils/sharepoint/__init__.py b/etl/src/etl/utils/sharepoint/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/etl/src/etl/utils/sharepoint/sharepoint.py b/etl/src/etl/utils/sharepoint/sharepoint.py new file mode 100644 index 0000000..f40d765 --- /dev/null +++ b/etl/src/etl/utils/sharepoint/sharepoint.py @@ -0,0 +1,272 @@ +""" +This file contains the functions which enable interaction with SharePoint via the API. +""" +from msal import ConfidentialClientApplication +from datetime import datetime, timedelta +import requests +from functools import wraps +import time +import logging +from io import BytesIO + +# Configure logging +logger = logging.getLogger(__name__) +if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + handler.setFormatter(formatter) + logger.addHandler(handler) +logger.setLevel(logging.INFO) + + +def handle_error(response): + """ + Handle errors based on HTTP status codes and log detailed information. + """ + try: + error_json = response.json().get('error', {}) + except ValueError: + error_json = {} + + error_code = error_json.get('code', 'unknownError') + error_message = error_json.get('message', 'No detailed error message provided.') + inner_error = error_json.get('innererror', {}) + details = error_json.get('details', []) + + logger.error(f"Error Code: {error_code}") + logger.error(f"Error Message: {error_message}") + if inner_error: + logger.error(f"Inner Error: {inner_error}") + if details: + logger.error(f"Error Details: {details}") + + if response.status_code == 401: + logger.error("Unauthorized. Token might be invalid.") + elif response.status_code == 403: + logger.error("Forbidden. Access denied to the requested resource.") + elif response.status_code == 404: + logger.error("Not Found. The requested resource doesn’t exist.") + elif response.status_code == 429: + retry_after = int(response.headers.get('Retry-After', 5)) # Default to 5 seconds if not provided + logger.warning(f"Too Many Requests. Retrying after {retry_after} seconds...") + time.sleep(retry_after) + return 'retry' + elif response.status_code in (500, 503): + retry_after = int(response.headers.get('Retry-After', 5)) # Default to 5 seconds if not provided + logger.error(f"Server error. Retrying after {retry_after} seconds...") + time.sleep(retry_after) + return 'retry' + else: + raise ValueError(f"API request failed with status code {response.status_code} - {error_message}") + + raise ValueError(f"API request failed with status code {response.status_code} - {error_message}") + + +def api_call_decorator(func): + """ + Handles various aspects of the API call, including refreshing the access token if needed and handling pagination. + :param func: The function to be decorated. + :return: The wrapped function. + """ + + @wraps(func) + def wrapper(self, *args, **kwargs): + try: + # Check and refresh the access token if needed + if self.is_access_token_expired(): + self.retrieve_access_token() + logger.info("Access token refreshed.") + + # Get the HTTP method, URL, and optionally data from the function + http_method, url, data = func(self, *args, **kwargs) + + # Initialize the results list and handle pagination if page_size is provided + results = [] + page_size = kwargs.get('page_size', None) + response_data = {} + + while url: + response = requests.request(http_method, url, headers=self.headers, json=data) + + # Handle the response + if response.status_code == 200: + response_json = response.json() # Store the response JSON + if page_size: + results.extend(response_json.get('value', [])) + url = response_json.get('@odata.nextLink', None) + else: + response_data = response_json # Capture the full response for consistency + break + else: + retry = handle_error(response) + if retry == 'retry': + continue + + if page_size: + response_data = {'value': results} + + return response_data + + except Exception as e: + logger.exception("An error occurred during the API call.") + raise e + + return wrapper + + +class SharePointClient: + access_token = None + access_token_request_timestamp = None + access_token_expiry = None + headers = None + + TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ" + + def __init__(self, tenant_id, client_id, client_secret, site_id, access_token=None, + access_token_expiration_details=None): + """ + Initializes the SharePointClient with necessary credentials and site information. + :param tenant_id: The tenant ID. + :param client_id: The client ID. + :param client_secret: The client secret. + :param site_id: The site ID. + :param access_token: The access token (optional) + :param access_token_expiration_details: The access token expiration details (optional) + """ + self.tenant_id = tenant_id + self.client_id = client_id + self.client_secret = client_secret + + if access_token: + if not access_token_expiration_details: + raise ValueError("Access token expiration details must be provided.") + self.access_token = access_token + self.set_access_token_expiration_details(access_token_expiration_details) + self.headers = { + 'Authorization': f"Bearer {self.access_token['access_token']}" + } + else: + self.retrieve_access_token() + + # Retrieve static identifiers + self.site_id = site_id + self.document_drive = self.get_documents_drive() + + def get_token_expiration_details(self): + """ + Returns the access token expiration details. Converts the datetime objects to strings for serialization. + :return: + """ + return { + 'access_token_request_timestamp': datetime.strftime( + self.access_token_request_timestamp, self.TIMESTAMP_FORMAT + ), + 'access_token_expiry': datetime.strftime(self.access_token_expiry, self.TIMESTAMP_FORMAT) + } + + def set_access_token_expiration_details(self, access_token_expiration_details): + """ + Sets the access token expiration details from a serialized dictionary. + :param access_token_expiration_details: The serialized access token expiration details. + :return: + """ + self.access_token_request_timestamp = datetime.strptime( + access_token_expiration_details['access_token_request_timestamp'], self.TIMESTAMP_FORMAT + ) + self.access_token_expiry = datetime.strptime( + access_token_expiration_details['access_token_expiry'], self.TIMESTAMP_FORMAT + ) + + def is_access_token_expired(self): + """ + Checks if the access token has expired. If it has, a new access token is retrieved. + :return: True if expired, False otherwise. + """ + return datetime.now() >= self.access_token_expiry + + def retrieve_access_token(self, refresh=False): + """ + Implements authentication using MSAL. + :param refresh: If True, force a refresh of the access token. + :return: None + """ + app = ConfidentialClientApplication( + self.client_id, + authority=f"https://login.microsoftonline.com/{self.tenant_id}", + client_credential=self.client_secret + ) + + scope = ["https://graph.microsoft.com/.default"] + + access_token_request_timestamp = datetime.now() + + if refresh: + logger.info("Forcing refresh of access token.") + token = app.acquire_token_for_client(scopes=scope) + else: + # Check if a token is already cached + token = app.acquire_token_silent(scope, account=None) + + if not token: + token = app.acquire_token_for_client(scopes=scope) + + if "access_token" not in token: + logger.error("Authentication failed.") + raise ValueError("Authentication failed") + + access_token_expiry = access_token_request_timestamp + timedelta( + seconds=token['expires_in'] - 20 + ) + + self.access_token = token + self.access_token_request_timestamp = access_token_request_timestamp + self.access_token_expiry = access_token_expiry + self.headers = { + 'Authorization': f"Bearer {self.access_token['access_token']}" + } + + logger.info("Access token retrieved successfully.") + + @api_call_decorator + def get_documents_drive(self): + """ + Get the document drive of the SharePoint site. + :return: Tuple containing HTTP method, URL, and None for data. + """ + url = f"https://graph.microsoft.com/v1.0/sites/{self.site_id}/drive" + logger.info(f"Getting document drive from URL: {url}") + return 'GET', url, None + + @api_call_decorator + def list_folder_contents(self, drive_id, folder_path: str, page_size: int = 100): + """ + This function will list the contents of a folder in SharePoint. + :param drive_id: The ID of the drive. + :param folder_path: The path of the folder. + :param page_size: The number of items per page (default is 100). + :return: Tuple containing HTTP method, URL, and None for data. + """ + url = f"https://graph.microsoft.com/v1.0/drives/{drive_id}/root:/{folder_path}:/children?$top={page_size}" + logger.info(f"Listing folder contents from URL: {url}") + return 'GET', url, None + + @staticmethod + def download_sharepoint_file(download_url): + """ + Downloads a file from the given URL and returns its content. + + :param download_url: The URL to download the file from. + :return: The content of the downloaded file. + """ + response = requests.get(download_url, stream=True) + response.raise_for_status() # Check if the request was successful + + file_content = BytesIO() + + # Read the file content into memory + for chunk in response.iter_content(chunk_size=8192): + file_content.write(chunk) + + file_content.seek(0) # Reset the file pointer to the beginning + + return file_content \ No newline at end of file