mirror of
https://github.com/Hestia-Homes/Model.git
synced 2026-06-08 11:17:27 +00:00
346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""
|
||
This module contains the functions which enable interaction with SharePoint via the Microsoft Graph API.

Documentation to get api_id:
https://answers.microsoft.com/en-us/msoffice/forum/all/what-is-the-best-way-to-findout-the-share-point/7b2d4183-4188-4cd5-8441-dd93207c5a01
"""
|
||
|
||
from typing import Any, BinaryIO, Dict, Optional
|
||
|
||
from msal import ConfidentialClientApplication
|
||
from datetime import datetime, timedelta
|
||
import requests
|
||
from functools import wraps
|
||
import time
|
||
from io import BytesIO
|
||
|
||
from utils.logger import setup_logger
|
||
|
||
# Api Documentation: https://learn.microsoft.com/en-us/graph/api/drive-get?view=graph-rest-1.0&tabs=http
|
||
|
||
logger = setup_logger()
|
||
|
||
|
||
def _parse_retry_after(header_value, default=5):
    """
    Convert a ``Retry-After`` header value into a non-negative number of seconds.

    The header may carry either an integer delay or an HTTP-date. A missing or
    unparseable value falls back to ``default``.

    :param header_value: Raw header value (``None`` when the header is absent).
    :param default: Delay in seconds used when the header is missing or invalid.
    :return: Number of seconds to wait, never negative.
    """
    # Local imports keep this block self-contained; ``timezone`` is not imported
    # at module level.
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    if header_value is None:
        return default

    text = str(header_value).strip()
    try:
        # Clamp at zero: time.sleep() rejects negative durations.
        return max(0, int(text))
    except ValueError:
        pass

    try:
        retry_at = parsedate_to_datetime(text)
    except (TypeError, ValueError, IndexError):
        return default

    if retry_at.tzinfo is None:
        # parsedate_to_datetime returns a naive datetime for a "-0000" offset.
        retry_at = retry_at.replace(tzinfo=timezone.utc)
    remaining = (retry_at - datetime.now(timezone.utc)).total_seconds()
    return max(0, int(remaining))


def handle_error(response):
    """
    Handle errors based on HTTP status codes and log detailed information.

    Logs the ``error`` object of the response body (code, message, inner error,
    details) when one is present, then reacts to the status code:

    * 429, 500 and 503: sleep for the ``Retry-After`` delay (5 seconds when the
      header is missing or invalid) and return ``"retry"``.
    * every other status: raise ``ValueError``.

    :param response: The failed HTTP response; must expose ``status_code``,
        ``headers`` and ``json()``.
    :return: The string ``"retry"`` when the caller should repeat the request.
    :raises ValueError: When the status code is not one of the retryable ones.
    """
    try:
        payload = response.json()
    except ValueError:
        payload = None

    # The body is not guaranteed to be a JSON object with an "error" object;
    # it may be a list, or "error" may be a plain string.
    error_json = payload.get("error") if isinstance(payload, dict) else None
    if isinstance(error_json, str):
        error_json = {"code": error_json}
    elif not isinstance(error_json, dict):
        error_json = {}

    error_code = error_json.get("code", "unknownError")
    error_message = error_json.get("message", "No detailed error message provided.")
    inner_error = error_json.get("innererror", {})
    details = error_json.get("details", [])

    logger.error(f"Error Code: {error_code}")
    logger.error(f"Error Message: {error_message}")
    if inner_error:
        logger.error(f"Inner Error: {inner_error}")
    if details:
        logger.error(f"Error Details: {details}")

    status = response.status_code
    if status == 401:
        logger.error("Unauthorized. Token might be invalid.")
    elif status == 403:
        logger.error("Forbidden. Access denied to the requested resource.")
    elif status == 404:
        logger.error("Not Found. The requested resource doesn’t exist.")
    elif status == 429:
        retry_after = _parse_retry_after(response.headers.get("Retry-After"))
        logger.warning(f"Too Many Requests. Retrying after {retry_after} seconds...")
        time.sleep(retry_after)
        return "retry"
    elif status in (500, 503):
        retry_after = _parse_retry_after(response.headers.get("Retry-After"))
        logger.error(f"Server error. Retrying after {retry_after} seconds...")
        time.sleep(retry_after)
        return "retry"

    # Every non-retryable status ends here, including 401/403/404 after logging.
    raise ValueError(
        f"API request failed with status code {response.status_code} - {error_message}"
    )
|
||
|
||
|
||
def api_call_decorator(func):
    """
    Handles various aspects of the API call, including refreshing the access token if needed and handling pagination.

    The decorated method must return a ``(http_method, url, data)`` tuple. The
    wrapper then:

    1. refreshes the access token when ``self.is_access_token_expired()`` is true;
    2. sends the request with ``self.headers`` and ``data`` as the JSON body;
    3. when the call has a truthy ``page_size`` (keyword, positional or the
       method's declared default), follows ``@odata.nextLink`` and returns
       ``{"value": [...all items...]}``; otherwise returns the parsed JSON of
       the single response;
    4. retries at most ``max_retries`` consecutive times when ``handle_error``
       asks for a retry.

    :param func: The function to be decorated.
    :return: The wrapped function.
    :raises ValueError: When the request fails with a non-retryable status or
        keeps failing after the retry budget is exhausted.
    """
    # Local import keeps this block self-contained.
    import inspect

    signature = inspect.signature(func)
    max_retries = 5  # consecutive retries allowed for a single request

    def resolve_page_size(instance, args, kwargs):
        # Bind the call so a positional page_size or the declared default is
        # seen, not only page_size passed as a keyword.
        try:
            bound = signature.bind(instance, *args, **kwargs)
        except TypeError:
            return kwargs.get("page_size")
        bound.apply_defaults()
        return bound.arguments.get("page_size", kwargs.get("page_size"))

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            # Check and refresh the access token if needed
            if self.is_access_token_expired():
                self.retrieve_access_token()
                logger.debug("Access token refreshed.")

            # Get the HTTP method, URL, and optionally data from the function
            http_method, url, data = func(self, *args, **kwargs)

            page_size = resolve_page_size(self, args, kwargs)
            results = []
            response_data = {}
            retries = 0

            while url:
                response = requests.request(
                    http_method, url, headers=self.headers, json=data
                )

                if response.status_code in (200, 201):
                    retries = 0  # the retry budget applies per request
                    response_json = response.json()
                    if not page_size:
                        # Capture the full response for consistency
                        response_data = response_json
                        break
                    results.extend(response_json.get("value", []))
                    url = response_json.get("@odata.nextLink", None)
                    continue

                # handle_error sleeps and returns "retry" for throttling or
                # server errors, and raises for everything else.
                outcome = handle_error(response)
                retries += 1
                if outcome != "retry" or retries > max_retries:
                    raise ValueError(
                        f"API request failed with status code {response.status_code} "
                        f"after {max_retries} retries."
                    )

            if page_size:
                response_data = {"value": results}

            return response_data

        except Exception:
            logger.exception("An error occurred during the API call.")
            # Bare raise keeps the original traceback intact.
            raise

    return wrapper
|
||
|
||
|
||
class SharePointClient:
    """
    Client for the document drive of a SharePoint site, using the Microsoft Graph API.

    Authentication uses the MSAL client-credentials flow. The token, the time it
    was requested and its computed expiry are kept on the instance so they can
    be serialized and handed to a later instance.
    """

    access_token = None
    access_token_request_timestamp = None
    access_token_expiry = None
    headers = None

    # NOTE(review): the format ends in a literal "Z", but the timestamps are
    # produced by the naive ``datetime.now()`` (local time) — confirm that
    # serialized details are only reused in the same timezone.
    TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"

    def __init__(
        self,
        tenant_id,
        client_id,
        client_secret,
        site_id,
        access_token=None,
        access_token_expiration_details=None,
    ):
        """
        Initializes the SharePointClient with necessary credentials and site information.

        When ``access_token`` is given it is reused instead of requesting a new
        one. The document drive of the site is then fetched and its id cached.

        :param tenant_id: The tenant ID.
        :param client_id: The client ID.
        :param client_secret: The client secret.
        :param site_id: The site ID.
        :param access_token: Previously obtained token dictionary containing an
            ``"access_token"`` entry (optional).
        :param access_token_expiration_details: Serialized expiration details as
            returned by ``get_token_expiration_details`` (required when
            ``access_token`` is given).
        :raises ValueError: If ``access_token`` is given without expiration
            details, or if authentication fails.
        """
        self.tenant_id = tenant_id
        self.client_id = client_id
        self.client_secret = client_secret

        if access_token:
            if not access_token_expiration_details:
                raise ValueError("Access token expiration details must be provided.")
            self.access_token = access_token
            self.set_access_token_expiration_details(access_token_expiration_details)
            self.headers = {
                "Authorization": f"Bearer {self.access_token['access_token']}"
            }
        else:
            self.retrieve_access_token()

        # Retrieve static identifiers
        self.site_id = site_id
        self.document_drive = self.get_documents_drive()
        self.document_drive_id = self.document_drive["id"]

    def get_token_expiration_details(self):
        """
        Returns the access token expiration details. Converts the datetime objects to strings for serialization.

        :return: Dictionary with ``access_token_request_timestamp`` and
            ``access_token_expiry`` formatted with ``TIMESTAMP_FORMAT``.
        """
        return {
            "access_token_request_timestamp": self.access_token_request_timestamp.strftime(
                self.TIMESTAMP_FORMAT
            ),
            "access_token_expiry": self.access_token_expiry.strftime(
                self.TIMESTAMP_FORMAT
            ),
        }

    def set_access_token_expiration_details(self, access_token_expiration_details):
        """
        Sets the access token expiration details from a serialized dictionary.

        :param access_token_expiration_details: The serialized access token
            expiration details (same keys and format as produced by
            ``get_token_expiration_details``).
        :return: None
        """
        self.access_token_request_timestamp = datetime.strptime(
            access_token_expiration_details["access_token_request_timestamp"],
            self.TIMESTAMP_FORMAT,
        )
        self.access_token_expiry = datetime.strptime(
            access_token_expiration_details["access_token_expiry"],
            self.TIMESTAMP_FORMAT,
        )

    def is_access_token_expired(self):
        """
        Checks if the access token has expired.

        This method only reports the state; it does not retrieve a new token.
        An unset expiry is reported as expired so that callers refresh.

        :return: True if expired, False otherwise.
        """
        if self.access_token_expiry is None:
            return True
        return datetime.now() >= self.access_token_expiry

    def retrieve_access_token(self, refresh=False):
        """
        Implements authentication using MSAL.

        Stores the token, the request timestamp, the computed expiry and the
        ``Authorization`` header on the instance.

        :param refresh: If True, force a refresh of the access token.
        :return: None
        :raises ValueError: If the response contains no access token.
        """
        app = ConfidentialClientApplication(
            self.client_id,
            authority=f"https://login.microsoftonline.com/{self.tenant_id}",
            client_credential=self.client_secret,
        )
        scope = ["https://graph.microsoft.com/.default"]

        # Taken before the request so the computed expiry errs on the early side.
        access_token_request_timestamp = datetime.now()

        if refresh:
            logger.debug("Forcing refresh of access token.")
            token = app.acquire_token_for_client(scopes=scope)
        else:
            # Check if a token is already cached
            token = app.acquire_token_silent(scope, account=None)
            if not token:
                token = app.acquire_token_for_client(scopes=scope)

        if "access_token" not in token:
            logger.error("Authentication failed.")
            raise ValueError("Authentication failed")

        # Expire 20 seconds early so a token is not used right at its limit.
        access_token_expiry = access_token_request_timestamp + timedelta(
            seconds=token["expires_in"] - 20
        )

        self.access_token = token
        self.access_token_request_timestamp = access_token_request_timestamp
        self.access_token_expiry = access_token_expiry
        self.headers = {"Authorization": f"Bearer {self.access_token['access_token']}"}

    @api_call_decorator
    def get_documents_drive(self):
        """
        Get the document drive of the SharePoint site.

        The request itself is sent by ``api_call_decorator``; callers receive
        the parsed response it produces.

        :return: Tuple containing HTTP method, URL, and None for data.
        """
        url = f"https://graph.microsoft.com/v1.0/sites/{self.site_id}/drive"
        return "GET", url, None

    @api_call_decorator
    def list_folder_contents(
        self, folder_path: str, page_size: int = 100
    ) -> Dict[str, Any]:
        """
        GET /drives/{drive-id}/root:/{folder_path}:/children

        This function will list the contents of a folder in SharePoint. The
        request itself is sent by ``api_call_decorator``; callers receive the
        parsed response it produces.

        :param folder_path: The path of the folder, relative to the drive root.
        :param page_size: The number of items per page (default is 100).
        :return: Tuple containing HTTP method, URL, and None for data.
        """
        url = f"https://graph.microsoft.com/v1.0/drives/{self.document_drive_id}/root:/{folder_path}:/children?$top={page_size}"
        return "GET", url, None

    @api_call_decorator
    def create_folder(self, file_name: str, folder_path: str) -> Dict[str, Any]:
        """
        POST /drives/{drive-id}/root:/{folder_path}:/children
        Content-Type: application/json
        {
            "name": "New Folder",
            "folder": { },
            "@microsoft.graph.conflictBehavior": "rename"
        }

        Creates a folder inside ``folder_path``. A name conflict is resolved by
        renaming the new folder. The request itself is sent by
        ``api_call_decorator``; callers receive the parsed response it produces.

        :param file_name: Name of the folder to create.
        :param folder_path: Path of the parent folder, relative to the drive root.
        :return: Tuple containing HTTP method, URL, and the JSON payload.
        """
        data: Dict[str, Any] = {
            "name": file_name,
            "folder": {},
            "@microsoft.graph.conflictBehavior": "rename",
        }
        url = f"https://graph.microsoft.com/v1.0/drives/{self.document_drive_id}/root:/{folder_path}:/children"
        return "POST", url, data

    def upload_file(
        self, file_name: str, file_stream: BinaryIO, sharepoint_parent_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Uploads a file to SharePoint using the Graph API.
        PUT /drives/{drive-id}/root:/{path-to-file}:/content

        :param file_name: Name of the file to upload
        :param file_stream: File content as a binary stream (e.g., BytesIO or open(file, 'rb'))
        :param sharepoint_parent_id: Folder path inside the drive under which the
            file is stored (it is placed into the URL as a path segment).
        :return: Response JSON from the API
        :raises ValueError: If the API answers with a non-retryable error status.
        """
        # Same token check the decorated API calls perform before a request.
        if self.is_access_token_expired():
            self.retrieve_access_token()
            logger.debug("Access token refreshed.")

        url = f"https://graph.microsoft.com/v1.0/drives/{self.document_drive_id}/root:/{sharepoint_parent_id}/{file_name}:/content"

        # Remember where the stream starts so a retry can resend the same bytes.
        seekable = getattr(file_stream, "seekable", None)
        start_offset = file_stream.tell() if callable(seekable) and seekable() else None

        response = requests.put(url, headers=self.headers, data=file_stream)

        if response.status_code in (200, 201):
            return response.json()

        if handle_error(response) == "retry":
            if start_offset is not None:
                # The first attempt consumed the stream; rewind before retrying.
                file_stream.seek(start_offset)
            # Argument order must follow the signature:
            # (file_name, file_stream, sharepoint_parent_id).
            return self.upload_file(file_name, file_stream, sharepoint_parent_id)
        return None

    @staticmethod
    def download_sharepoint_file(download_url):
        """
        Downloads a file from the given URL and returns its content.

        :param download_url: The URL to download the file from.
        :return: In-memory ``BytesIO`` holding the file content, positioned at
            the beginning.
        :raises requests.HTTPError: If the response status indicates an error.
        """
        file_content = BytesIO()

        # The context manager releases the streamed connection even when the
        # status check or the read fails.
        with requests.get(download_url, stream=True) as response:
            response.raise_for_status()  # Check if the request was successful

            # Read the file content into memory
            for chunk in response.iter_content(chunk_size=8192):
                file_content.write(chunk)

        file_content.seek(0)  # Reset the file pointer to the beginning
        return file_content
|