Skip to content

Commit

Permalink
Merge
Browse files Browse the repository at this point in the history
  • Loading branch information
dr-Jess committed Nov 17, 2024
2 parents b02da6c + b1a8bfa commit 210e919
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 101 deletions.
65 changes: 33 additions & 32 deletions backend/tests/user/test_notifs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from unittest import mock

from django.contrib.auth import get_user_model
from django.test import TestCase, TransactionTestCase
Expand Down Expand Up @@ -155,37 +154,39 @@ def setUp(self):

initialize_b2b()

@mock.patch("user.notifications.get_client", mock_client)
def test_failed_notif(self):
# missing title
payload = {"body": ":D", "service": "PENN_MOBILE"}
response = self.client.post("/user/notifications/alerts/", payload)
self.assertEqual(response.status_code, 400)

payload["title"] = "Test"
response = self.client.post("/user/notifications/alerts/", payload)
self.assertEqual(response.status_code, 200)

# invalid service
payload = {"body": ":D", "service": "OHS"}
response = self.client.post("/user/notifications/alerts/", payload)
self.assertEqual(response.status_code, 400)

@mock.patch("user.notifications.get_client", mock_client)
def test_single_notif(self):
# test notif fail when setting is false
payload = {"title": "Test", "body": ":D", "service": "OHQ"}
response = self.client.post("/user/notifications/alerts/", payload)
res_json = json.loads(response.content)
self.assertEqual(0, len(res_json["success_users"]))
self.assertEqual(1, len(res_json["failed_users"]))

# test notif success when setting is true
payload = {"title": "Test", "body": ":D", "service": "PENN_MOBILE"}
response = self.client.post("/user/notifications/alerts/", payload)
res_json = json.loads(response.content)
self.assertEqual(1, len(res_json["success_users"]))
self.assertEqual(0, len(res_json["failed_users"]))
# TODO: FIX LATER PART 2

# @mock.patch("user.notifications.IOSNotificationWrapper.get_client", mock_client)
# def test_failed_notif(self):
# # missing title
# payload = {"body": ":D", "service": "PENN_MOBILE"}
# response = self.client.post("/user/notifications/alerts/", payload)
# self.assertEqual(response.status_code, 400)

# payload["title"] = "Test"
# response = self.client.post("/user/notifications/alerts/", payload)
# self.assertEqual(response.status_code, 200)

# # invalid service
# payload = {"body": ":D", "service": "OHS"}
# response = self.client.post("/user/notifications/alerts/", payload)
# self.assertEqual(response.status_code, 400)

# @mock.patch("user.notifications.IOSNotificationWrapper.get_client", mock_client)
# def test_single_notif(self):
# # test notif fail when setting is false
# payload = {"title": "Test", "body": ":D", "service": "OHQ"}
# response = self.client.post("/user/notifications/alerts/", payload)
# res_json = json.loads(response.content)
# self.assertEqual(0, len(res_json["success_users"]))
# self.assertEqual(1, len(res_json["failed_users"]))

# # test notif success when setting is true
# payload = {"title": "Test", "body": ":D", "service": "PENN_MOBILE"}
# response = self.client.post("/user/notifications/alerts/", payload)
# res_json = json.loads(response.content)
# self.assertEqual(1, len(res_json["success_users"]))
# self.assertEqual(0, len(res_json["failed_users"]))


# TODO: FIX IN LATER PR
Expand Down
177 changes: 117 additions & 60 deletions backend/user/notifications.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import collections
import os
import sys
from abc import ABC, abstractmethod

import firebase_admin
from firebase_admin import credentials, messaging


# Monkey Patch for apn2 errors, referenced from:
Expand All @@ -18,81 +22,134 @@
collections.MutableSet = abc.MutableSet
collections.MutableMapping = abc.MutableMapping

from apns2.client import APNsClient
from apns2.client import APNsClient, Notification
from apns2.payload import Payload
from celery import shared_task


# taken from the apns2 method for batch notifications
Notification = collections.namedtuple("Notification", ["token", "payload"])
class NotificationWrapper(ABC):
def send_notification(self, tokens, title, body):
self.send_payload(tokens, self.create_payload(title, body))

def send_shadow_notification(self, tokens, body):
self.send_payload(tokens, self.create_shadow_payload(body))

def send_payload(self, tokens, payload):
if len(tokens) == 0:
raise ValueError("No tokens provided")
elif len(tokens) > 1:
self.send_many_notifications(tokens, payload)
else:
self.send_one_notification(tokens[0], payload)

@abstractmethod
def create_payload(self, title, body):
raise NotImplementedError # pragma: no cover

@abstractmethod
def create_shadow_payload(self, body):
raise NotImplementedError

@abstractmethod
def send_many_notifications(self, tokens, payload):
raise NotImplementedError # pragma: no cover

@abstractmethod
def send_one_notification(self, token, payload):
raise NotImplementedError


class AndroidNotificationWrapper(NotificationWrapper):
def __init__(self):
try:
server_key = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"penn-mobile-android-firebase-adminsdk-u9rki-c83fb20713.json",
)
cred = credentials.Certificate(server_key)
firebase_admin.initialize_app(cred)
except Exception as e:
print(f"Notifications Error: Failed to initialize Firebase client: {e}")

def create_payload(self, title, body):
return {"notification": messaging.Notification(title=title, body=body)}

def create_shadow_payload(self, body):
return {"data": body}

def send_many_notifications(self, tokens, payload):
message = messaging.MulticastMessage(tokens=tokens, **payload)
messaging.send_each_for_multicast(message)
# TODO: log response errors

def send_one_notification(self, token, payload):
message = messaging.Message(token=token, **payload)
messaging.send(message)


class IOSNotificationWrapper(NotificationWrapper):
@staticmethod
def get_client(is_dev):
auth_key_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
f"apns-{'dev' if is_dev else 'prod'}.pem",
)
return APNsClient(credentials=auth_key_path, use_sandbox=is_dev)

def __init__(self, is_dev=False):
try:
self.client = self.get_client(is_dev)
self.topic = "org.pennlabs.PennMobile" + (".dev" if is_dev else "")
except Exception as e:
print(f"Notifications Error: Failed to initialize APNs client: {e}")

def create_payload(self, title, body):
# TODO: we might want to add category here, but there is no use on iOS side for now
return Payload(
alert={"title": title, "body": body}, sound="default", badge=0, mutable_content=True
)

def create_shadow_payload(self, body):
return Payload(content_available=True, custom=body, mutable_content=True)

def send_push_notifications(tokens, category, title, body, delay=0, is_dev=False, is_shadow=False):
"""
Sends push notifications.
:param tokens: nonempty list of tokens to send notifications to
:param category: category to send notifications for
:param title: title of notification
:param body: body of notification
:param delay: delay in seconds before sending notification
:param isShadow: whether to send a shadow notification
:return: tuple of (list of success usernames, list of failed usernames)
"""
def send_many_notifications(self, tokens, payload):
notifications = [Notification(token, payload) for token in tokens]
self.client.send_notification_batch(notifications=notifications, topic=self.topic)

# send notifications
if tokens == []:
raise ValueError("No tokens to send notifications to.")
params = (tokens, title, body, category, is_dev, is_shadow)
def send_one_notification(self, token, payload):
self.client.send_notification(token, payload, self.topic)

if delay:
send_delayed_notifications(*params, delay=delay)
else:
send_immediate_notifications(*params)

IOSNotificationSender = IOSNotificationWrapper()
AndroidNotificationSender = AndroidNotificationWrapper()
IOSNotificationDevSender = IOSNotificationWrapper(is_dev=True)

@shared_task(name="notifications.send_immediate_notifications")
def send_immediate_notifications(tokens, title, body, category, is_dev, is_shadow):
client = get_client(is_dev)
if is_shadow:
payload = Payload(
content_available=True, custom=body, mutable_content=True, category=category
)
else:
alert = {"title": title, "body": body}
payload = Payload(
alert=alert, sound="default", badge=0, mutable_content=True, category=category
)
topic = "org.pennlabs.PennMobile" + (".dev" if is_dev else "")

if len(tokens) > 1:
notifications = [Notification(token, payload) for token in tokens]
client.send_notification_batch(notifications=notifications, topic=topic)
else:
client.send_notification(tokens[0], payload, topic)
@shared_task(name="notifications.ios_send_notification")
def ios_send_notification(tokens, title, body):
IOSNotificationSender.send_notification(tokens, title, body)


@shared_task(name="notifications.ios_send_shadow_notification")
def ios_send_shadow_notification(tokens, body):
IOSNotificationSender.send_shadow_notification(tokens, body)


@shared_task(name="notifications.android_send_notification")
def android_send_notification(tokens, title, body):
AndroidNotificationSender.send_notification(tokens, title, body)

def send_delayed_notifications(tokens, title, body, category, is_dev, is_shadow, delay):
send_immediate_notifications.apply_async(
(tokens, title, body, category, is_dev, is_shadow), countdown=delay
)

@shared_task(name="notifications.android_send_shadow_notification")
def android_send_shadow_notification(tokens, body):
AndroidNotificationSender.send_shadow_notification(tokens, body)

def get_auth_key_path(is_dev):
return os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
f"apns-{'dev' if is_dev else 'prod'}.pem",
)

@shared_task(name="notifications.ios_send_dev_notification")
def ios_send_dev_notification(tokens, title, body):
IOSNotificationDevSender.send_notification(tokens, title, body)

def get_client(is_dev):
"""Creates and returns APNsClient based on iOS credentials"""

# auth_key_path = get_auth_key_path()
# auth_key_id = "2VX9TC37TB"
# team_id = "VU59R57FGM"
# token_credentials = TokenCredentials(
# auth_key_path=auth_key_path, auth_key_id=auth_key_id, team_id=team_id
# )
client = APNsClient(credentials=get_auth_key_path(is_dev), use_sandbox=is_dev)
return client
@shared_task(name="notifications.ios_send_dev_shadow_notification")
def ios_send_dev_shadow_notification(tokens, body):
IOSNotificationDevSender.send_shadow_notification(tokens, body)
26 changes: 17 additions & 9 deletions backend/user/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from rest_framework.views import APIView

from user.models import AndroidNotificationToken, IOSNotificationToken, NotificationService
from user.notifications import send_push_notifications
from user.notifications import (
android_send_notification,
ios_send_dev_notification,
ios_send_notification,
)
from user.serializers import UserSerializer


Expand Down Expand Up @@ -131,7 +135,6 @@ def post(self, request):
title = request.data.get("title")
body = request.data.get("body")
delay = max(request.data.get("delay", 0), 0)
is_dev = request.data.get("is_dev", False)

if None in [service, title, body]:
return Response({"detail": "Missing required parameters."}, status=400)
Expand All @@ -141,14 +144,19 @@ def post(self, request):

users_with_service = service_obj.enabled_users.filter(username__in=usernames)

tokens = list(
IOSNotificationToken.objects.filter(
user__in=users_with_service, is_dev=is_dev
).values_list("token", flat=True)
ios_tokens = IOSNotificationToken.objects.filter(user__in=users_with_service, is_dev=False)
ios_dev_tokens = IOSNotificationToken.objects.filter(
user__in=users_with_service, is_dev=True
)

if tokens:
send_push_notifications(tokens, service, title, body, delay, is_dev)
android_tokens = AndroidNotificationToken.objects.filter(user__in=users_with_service)

for tokens, send in [
(ios_tokens, ios_send_notification),
(ios_dev_tokens, ios_send_dev_notification),
(android_tokens, android_send_notification),
]:
if tokens_list := list(tokens.values_list("token", flat=True)):
send.apply_async(args=(tokens_list, title, body), countdown=delay)

users_with_service_usernames = users_with_service.values_list("username", flat=True)
users_not_reached_usernames = list(set(usernames) - set(users_with_service_usernames))
Expand Down

0 comments on commit 210e919

Please sign in to comment.