summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--COMMANDS.md22
-rw-r--r--README.md1
-rw-r--r--db.py11
-rw-r--r--main.py103
4 files changed, 126 insertions, 11 deletions
diff --git a/COMMANDS.md b/COMMANDS.md
new file mode 100644
index 0000000..8479ac0
--- /dev/null
+++ b/COMMANDS.md
@@ -0,0 +1,22 @@
+# Commands
+
+## `register`
+Register an account.
+
+### Required fields
+- `username`: 1-20 characters, a-z0-9-_
+- `password`: 1-255 characters
+- `invite_code`: 16 characters
+
+## `login_pswd`
+Log in using a password.
+
+### Required fields
+- `username`: 1-20 characters
+- `password`: 1-255 characters
+
+## `login_token`
+Log in using a token.
+
+### Required fields
+- `token`: 32-127 characters
\ No newline at end of file
diff --git a/README.md b/README.md
index 6181506..d009856 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@ soktdeer rewrite
   - [ ] get message
   - [ ] chat history (v1)
 - [ ] get inbox
+- [ ] client names
 ### moderation (from hydrogen)
 - [ ] bans
 - [ ] invite codes
diff --git a/db.py b/db.py
index 3feef6c..88dc79b 100644
--- a/db.py
+++ b/db.py
@@ -4,6 +4,7 @@ from passlib.hash import scrypt
 from pymongo.mongo_client import MongoClient
 from pymongo.server_api import ServerApi
 from dotenv import load_dotenv
+import time
 
 load_dotenv()
 
@@ -63,16 +64,14 @@ class acc:
             return "fail"
         return True
 
-    def verify(username, token):
-        user = usersd.find_one({"username": username})
+    def verify(token):
+        user = usersd.find_one({"secure.token": token})
         if not user:
             return "notExists"
         if user["banned_until"] > round(time.time()):
-            return "banned"
-        elif user["secure"]["token"] != token:
-            return "unauthorized"
+            return {"banned": True, "username": user["username"]}
         else:
-            return True
+            return {"banned": False, "username": user["username"]}
 
     def verify_pswd(username, password):
         user = usersd.find_one({"username": username})
diff --git a/main.py b/main.py
index 9446b71..a92d4f9 100644
--- a/main.py
+++ b/main.py
@@ -8,6 +8,9 @@ from passlib.hash import scrypt
 import db

 import uuid

 import secrets

+import time

+

+version = "Helium-0.0.0a"

 

 addr = "localhost"

 port = 3636

@@ -20,7 +23,9 @@ error_contexts = {
     "invalidUsername": "Username is invalid. It may contain characters that are not permitted in usernames.",

     "invalidInvite": "The invite code you are trying to use is invalid or has expired.",

     "usernameTaken": "This username has been taken.",

-    "notExists": "The requested value does not exist."

+    "notExists": "The requested value does not exist.",

+    "lockdown": "Maintenance is in progress.",

+    "authed": "You are already authenticated."

 }

 

 ulist = {}

@@ -31,18 +36,21 @@ invite_codes = ["STANLEYYELNATSAB"]
 locked = False

 

 class util:

-    def error(code, listener):

+    def error(code, listener, data=None):

         if code in error_contexts:

             context = error_contexts[code]

         else:

             context = ""

-        return json.dumps({

+        response = {

             "error": True,

             "code": code,

             "form": "helium-util",

             "context": context,

             "listener": listener

-        })

+        }

+        if data:

+            response.update(data)

+        return json.dumps(response)

     

     def field_check(expects, gets):

         for i in expects:

@@ -52,9 +60,38 @@ class util:
                 if len(gets[i]) not in expects[i]:

                     return "lengthInvalid"

         return True

+    

+    def greeting():

+        return json.dumps({

+            "command": "greet",

+            "version": version,

+            "ulist": ulist,

+            "messages": [],

+            "locked": locked

+        })

+    

+    def ulist():

+        broadcast(clients, json.dumps({

+            "command": "ulist",

+            "ulist": ulist

+        }))

+    

+    def author_data(username):

+        data = db.acc.get(username)

+        del data["secure"]

+        del data["profile"]

+        return data

+    

+    def authorize(username, conn_id, client=""):

+        ulist[username] = client

+        user_clients[conn_id] = {"username": username, "client": client}

+        data = db.acc.get(username)

+        del data["secure"]

+        return data

 

 async def handler(websocket):

     clients.append(websocket)

+    await websocket.send(util.greeting())

     async for message in websocket:

         try:

             r = json.loads(message)

@@ -72,6 +109,12 @@ async def handler(websocket):
             if fc != True:

                 await websocket.send(util.error(fc, listener))

                 continue

+            if str(websocket.id) in user_clients:

+                await websocket.send(util.error("authed", listener))

+                continue

+            if locked:

+                await websocket.send(util.error("lockdown", listener))

+                continue

             r["username"] = r["username"].lower()

             r["invite_code"] = r["invite_code"].upper()

             if not re.fullmatch("[a-z0-9-_]{1,20}", r["username"]):

@@ -87,10 +130,12 @@ async def handler(websocket):
                 "_id": str(uuid.uuid4()),

                 "username": r["username"],

                 "display_name": r["username"],

+                "created": round(time.time()),

                 "avatar": None,

                 "bot": False,

                 "verified": False,

                 "banned_until": 0,

+                "permissions": [],

                 "profile": {

                     "bio": "",

                     "lastfm": "",

@@ -116,8 +161,56 @@ async def handler(websocket):
             if fc != True:

                 await websocket.send(util.error(fc, listener))

                 continue

+            if str(websocket.id) in user_clients:

+                await websocket.send(util.error("authed", listener))

+                continue

             r["username"] = r["username"].lower()

-

+            if locked:

+                perms = db.acc.get_perms(r["username"])

+                if type(perms) != list:

+                    await websocket.send(util.error("lockdown", listener))

+                    continue

+                if "LOCK" not in perms:

+                    await websocket.send(util.error("lockdown", listener))

+                    continue

+            valid = db.acc.verify_pswd(r["username"], r["password"])

+            if type(valid) == dict:

+                userdata = util.authorize(r["username"], str(websocket.id))

+                await websocket.send(json.dumps({"error": False, "token": valid["token"], "user": userdata, "listener": listener}))

+                util.ulist()

+                continue

+            elif valid == "banned":

+                await websocket.send(util.error(valid, listener, db.acc.get_ban(r["username"])))

+                continue

+            else:

+                await websocket.send(util.error(valid, listener))

+                continue

+        elif r["command"] == "login_token":

+            fc = util.field_check({"token": range(32,128)}, r)

+            if fc != True:

+                await websocket.send(util.error(fc, listener))

+                continue

+            if str(websocket.id) in user_clients:

+                await websocket.send(util.error("authed", listener))

+                continue

+            if locked:

+                await websocket.send(util.error("lockdown", listener))

+                continue

+            valid = db.acc.verify(r["token"])

+            if type(valid) == dict:

+                if valid["banned"]:

+                    await websocket.send(util.error("banned", listener, db.acc.get_ban(valid["username"])))

+                    continue

+                else:

+                    userdata = util.authorize(valid["username"], str(websocket.id))

+                    await websocket.send(json.dumps({"error": False, "user": userdata, "listener": listener}))

+                    util.ulist()

+                    continue

+            else:

+                await websocket.send(util.error(valid, listener))

+                continue

+        elif r["command"] == "ping":

+            pass

         else:

             await websocket.send(util.error("malformedJson", listener))

     if websocket in clients: