From 9bee70c10600d48391e0fcb44dbdf13e2e771850 Mon Sep 17 00:00:00 2001 From: Vomitblood Date: Thu, 13 Feb 2025 02:44:00 +0800 Subject: [PATCH] changed response format --- server-ml/README.md | 46 ++++++++++++++++++++++++++++++++++ server-ml/main.py | 25 +++++++++++-------- server-ml/pyproject.toml | 2 +- server-ml/requirements.txt | 51 +++++++++++++++++++++++++++++++++++--- server-ml/test-cases.md | 46 ---------------------------------- server-ml/training.py | 35 -------------------------- 6 files changed, 110 insertions(+), 95 deletions(-) delete mode 100644 server-ml/test-cases.md delete mode 100644 server-ml/training.py diff --git a/server-ml/README.md b/server-ml/README.md index e69de29..5e58c7c 100644 --- a/server-ml/README.md +++ b/server-ml/README.md @@ -0,0 +1,46 @@ +# Test Cases + +## 1. Basic test cases + +[Good] `SELECT _ FROM users WHERE id = 1` +[Bad] `SELECT _ FROM users WHERE id = 1 OR 1=1` + +## 2. Authentication bypass cases + +[Bad] `SELECT _ FROM users WHERE username = 'admin' --` +[Bad] `SELECT _ FROM users WHERE username = 'admin' #` +[Bad] `SELECT \* FROM users WHERE username = 'admin' OR '1'='1'` + +## 3. Union based injection cases + +[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT null, 'hacker'` +[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT 1, 'hacked' FROM dual` +[Bad] `SELECT database() UNION SELECT 1` + +## 4. Error based injection cases + +[Bad] `SELECT _ FROM users WHERE id = 1 AND (SELECT 1 FROM users WHERE id=2)=1` +[Bad] `SELECT _ FROM users WHERE id = (SELECT COUNT(\*) FROM users)` + +## 5. Blind SQL injection cases + +[Bad] `SELECT _ FROM users WHERE id = 1; WAITFOR DELAY '00:00:10' --` +[Bad] `SELECT _ FROM users WHERE username = 'admin' AND 1=1` + +## 6. Hex and Base64 encoded injection cases + +[Bad] `SELECT _ FROM users WHERE username = 0x61646D696E` +[Bad] `SELECT _ FROM users WHERE username = 'YWRtaW4='` + +## 7. False positives cases + +[Good] `SELECT _ FROM users WHERE id = 5` +[Good] `SELECT users.name, orders.amount FROM users JOIN orders ON users.id = orders.user_id` +[Good] `SELECT _ FROM users WHERE username = ? AND password = ?` + +## 8. Edge cases + +[Good] `""` +[Bad] `'; --` +[Good] `12345` +[Good] `asdkjhasdkjh` diff --git a/server-ml/main.py b/server-ml/main.py index 676ebd0..0323b5c 100644 --- a/server-ml/main.py +++ b/server-ml/main.py @@ -2,19 +2,24 @@ from flask import Flask, request, jsonify import torch from transformers import MobileBertTokenizer, MobileBertForSequenceClassification -# Initialize Flask app +print("Starting server...") + +# initialize Flask app app = Flask(__name__) -# Set device (GPU if available, otherwise CPU) -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +# set device, use gpu if available +device = torch.device("cpu") +# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# Load tokenizer and model +# load tokenizer and model +print("Loading model...") tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased") model = MobileBertForSequenceClassification.from_pretrained("cssupport/mobilebert-sql-injection-detect") model.to(device) model.eval() +print("Model loaded") -# Function to predict SQL injection +# function for model to predict sql injection def predict(text): inputs = tokenizer(text, padding=False, truncation=True, return_tensors="pt", max_length=512) input_ids = inputs["input_ids"].to(device) @@ -30,7 +35,7 @@ def predict(text): return predicted_class, confidence -# Define API endpoint +# the api endpoint @app.route("/predict", methods=["POST"]) def classify_query(): data = request.json @@ -40,16 +45,16 @@ def classify_query(): query = data["query"] predicted_class, confidence = predict(query) - # Thresholding (if confidence > 0.7, mark as SQL Injection) + # if >0.7, then mark as bad is_vulnerable = predicted_class == 1 and confidence > 0.7 result = { "query": query, - "classification": "SQL Injection Detected" if is_vulnerable else "No SQL Injection Detected", + "result": "fail" if is_vulnerable else "pass", "confidence": round(confidence, 2) } return jsonify(result) -# Run Flask server +# run the flask server if __name__ == "__main__": - app.run(host="0.0.0.0", port=5000, debug=True) + app.run(host="0.0.0.0", port=5000, debug=False) diff --git a/server-ml/pyproject.toml b/server-ml/pyproject.toml index 311b018..2980f65 100644 --- a/server-ml/pyproject.toml +++ b/server-ml/pyproject.toml @@ -3,7 +3,7 @@ name = "server-ml" version = "0.1.0" description = "Add your description here" readme = "README.md" -requires-python = ">=3.13" +requires-python = ">=3.11" dependencies = [ "ctransformers>=0.2.27", "flask>=3.1.0", diff --git a/server-ml/requirements.txt b/server-ml/requirements.txt index 0ca250e..4c3822b 100644 --- a/server-ml/requirements.txt +++ b/server-ml/requirements.txt @@ -1,3 +1,48 @@ -flask -scikit-learn -joblib \ No newline at end of file +accelerate==1.3.0 +blinker==1.9.0 +certifi==2025.1.31 +charset-normalizer==3.4.1 +click==8.1.8 +ctransformers==0.2.27 +filelock==3.17.0 +flask==3.1.0 +fsspec==2025.2.0 +huggingface-hub==0.28.1 +idna==3.10 +itsdangerous==2.2.0 +jinja2==3.1.5 +markupsafe==3.0.2 +mpmath==1.3.0 +networkx==3.4.2 +numpy==2.2.2 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparselt-cu12==0.6.2 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +packaging==24.2 +peft==0.14.0 +psutil==6.1.1 +py-cpuinfo==9.0.0 +pyyaml==6.0.2 +regex==2024.11.6 +requests==2.32.3 +safetensors==0.5.2 +setuptools==75.8.0 +sympy==1.13.1 +tokenizers==0.21.0 +torch==2.6.0 +tqdm==4.67.1 +transformers==4.48.2 +triton==3.2.0 +typing-extensions==4.12.2 +urllib3==2.3.0 +werkzeug==3.1.3 diff --git a/server-ml/test-cases.md b/server-ml/test-cases.md deleted file mode 100644 index 5e58c7c..0000000 --- a/server-ml/test-cases.md +++ /dev/null @@ -1,46 +0,0 @@ -# Test Cases - -## 1. Basic test cases - -[Good] `SELECT _ FROM users WHERE id = 1` -[Bad] `SELECT _ FROM users WHERE id = 1 OR 1=1` - -## 2. Authentication bypass cases - -[Bad] `SELECT _ FROM users WHERE username = 'admin' --` -[Bad] `SELECT _ FROM users WHERE username = 'admin' #` -[Bad] `SELECT \* FROM users WHERE username = 'admin' OR '1'='1'` - -## 3. Union based injection cases - -[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT null, 'hacker'` -[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT 1, 'hacked' FROM dual` -[Bad] `SELECT database() UNION SELECT 1` - -## 4. Error based injection cases - -[Bad] `SELECT _ FROM users WHERE id = 1 AND (SELECT 1 FROM users WHERE id=2)=1` -[Bad] `SELECT _ FROM users WHERE id = (SELECT COUNT(\*) FROM users)` - -## 5. Blind SQL injection cases - -[Bad] `SELECT _ FROM users WHERE id = 1; WAITFOR DELAY '00:00:10' --` -[Bad] `SELECT _ FROM users WHERE username = 'admin' AND 1=1` - -## 6. Hex and Base64 encoded injection cases - -[Bad] `SELECT _ FROM users WHERE username = 0x61646D696E` -[Bad] `SELECT _ FROM users WHERE username = 'YWRtaW4='` - -## 7. False positives cases - -[Good] `SELECT _ FROM users WHERE id = 5` -[Good] `SELECT users.name, orders.amount FROM users JOIN orders ON users.id = orders.user_id` -[Good] `SELECT _ FROM users WHERE username = ? AND password = ?` - -## 8. Edge cases - -[Good] `""` -[Bad] `'; --` -[Good] `12345` -[Good] `asdkjhasdkjh` diff --git a/server-ml/training.py b/server-ml/training.py deleted file mode 100644 index 74ad56b..0000000 --- a/server-ml/training.py +++ /dev/null @@ -1,35 +0,0 @@ -import joblib -from sklearn.feature_extraction.text import CountVectorizer -from sklearn.linear_model import LogisticRegression -from sklearn.pipeline import make_pipeline -from sklearn.model_selection import train_test_split - -# random data -data = [ - ("' OR '1'='1", 1), - ("SELECT * FROM users WHERE id=1", 1), - ("DROP TABLE users;", 1), - ("username=admin'--", 1), - ("hello world", 0), - ("this is a normal query", 0), - ("select data from table", 0), - ("just another harmless input", 0), -] - -queries, labels = zip(*data) - -# split data into training and testing sets -X_train, X_test, y_train, y_test = train_test_split( - queries, labels, test_size=0.2, random_state=42 -) - -# build a pipeline with a vectorizer and a logistic regression model -pipeline = make_pipeline(CountVectorizer(), LogisticRegression()) - -# train the model -pipeline.fit(X_train, y_train) - -# save the model to a file -joblib.dump(pipeline, "model.pkl") - -print("Model trained and saved to model.pkl")