changed response format
This commit is contained in:
parent
b2250966fb
commit
9bee70c106
|
@ -0,0 +1,46 @@
|
||||||
|
# Test Cases
|
||||||
|
|
||||||
|
## 1. Basic test cases
|
||||||
|
|
||||||
|
[Good] `SELECT * FROM users WHERE id = 1`
|
||||||
|
[Bad] `SELECT * FROM users WHERE id = 1 OR 1=1`
|
||||||
|
|
||||||
|
## 2. Authentication bypass cases
|
||||||
|
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 'admin' --`
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 'admin' #`
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 'admin' OR '1'='1'`
|
||||||
|
|
||||||
|
## 3. Union-based injection cases
|
||||||
|
|
||||||
|
[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT null, 'hacker'`
|
||||||
|
[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT 1, 'hacked' FROM dual`
|
||||||
|
[Bad] `SELECT database() UNION SELECT 1`
|
||||||
|
|
||||||
|
## 4. Error-based injection cases
|
||||||
|
|
||||||
|
[Bad] `SELECT * FROM users WHERE id = 1 AND (SELECT 1 FROM users WHERE id=2)=1`
|
||||||
|
[Bad] `SELECT * FROM users WHERE id = (SELECT COUNT(*) FROM users)`
|
||||||
|
|
||||||
|
## 5. Blind SQL injection cases
|
||||||
|
|
||||||
|
[Bad] `SELECT * FROM users WHERE id = 1; WAITFOR DELAY '00:00:10' --`
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 'admin' AND 1=1`
|
||||||
|
|
||||||
|
## 6. Hex- and Base64-encoded injection cases
|
||||||
|
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 0x61646D696E`
|
||||||
|
[Bad] `SELECT * FROM users WHERE username = 'YWRtaW4='`
|
||||||
|
|
||||||
|
## 7. False-positive cases
|
||||||
|
|
||||||
|
[Good] `SELECT * FROM users WHERE id = 5`
|
||||||
|
[Good] `SELECT users.name, orders.amount FROM users JOIN orders ON users.id = orders.user_id`
|
||||||
|
[Good] `SELECT * FROM users WHERE username = ? AND password = ?`
|
||||||
|
|
||||||
|
## 8. Edge cases
|
||||||
|
|
||||||
|
[Good] `""`
|
||||||
|
[Bad] `'; --`
|
||||||
|
[Good] `12345`
|
||||||
|
[Good] `asdkjhasdkjh`
|
|
@ -2,19 +2,24 @@ from flask import Flask, request, jsonify
|
||||||
import torch
|
import torch
|
||||||
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
|
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
|
||||||
|
|
||||||
# Initialize Flask app
|
print("Starting server...")
|
||||||
|
|
||||||
|
# initialize Flask app
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
||||||
# Set device (GPU if available, otherwise CPU)
|
# set device: pinned to CPU (the GPU auto-detection line is commented out below)
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cpu")
|
||||||
|
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
# Load tokenizer and model
|
# load tokenizer and model
|
||||||
|
print("Loading model...")
|
||||||
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
|
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
|
||||||
model = MobileBertForSequenceClassification.from_pretrained("cssupport/mobilebert-sql-injection-detect")
|
model = MobileBertForSequenceClassification.from_pretrained("cssupport/mobilebert-sql-injection-detect")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
print("Model loaded")
|
||||||
|
|
||||||
# Function to predict SQL injection
|
# function for model to predict sql injection
|
||||||
def predict(text):
|
def predict(text):
|
||||||
inputs = tokenizer(text, padding=False, truncation=True, return_tensors="pt", max_length=512)
|
inputs = tokenizer(text, padding=False, truncation=True, return_tensors="pt", max_length=512)
|
||||||
input_ids = inputs["input_ids"].to(device)
|
input_ids = inputs["input_ids"].to(device)
|
||||||
|
@ -30,7 +35,7 @@ def predict(text):
|
||||||
|
|
||||||
return predicted_class, confidence
|
return predicted_class, confidence
|
||||||
|
|
||||||
# Define API endpoint
|
# the api endpoint
|
||||||
@app.route("/predict", methods=["POST"])
|
@app.route("/predict", methods=["POST"])
|
||||||
def classify_query():
|
def classify_query():
|
||||||
data = request.json
|
data = request.json
|
||||||
|
@ -40,16 +45,16 @@ def classify_query():
|
||||||
query = data["query"]
|
query = data["query"]
|
||||||
predicted_class, confidence = predict(query)
|
predicted_class, confidence = predict(query)
|
||||||
|
|
||||||
# Thresholding (if confidence > 0.7, mark as SQL Injection)
|
# mark as bad only when the model predicts class 1 with confidence above 0.7
|
||||||
is_vulnerable = predicted_class == 1 and confidence > 0.7
|
is_vulnerable = predicted_class == 1 and confidence > 0.7
|
||||||
result = {
|
result = {
|
||||||
"query": query,
|
"query": query,
|
||||||
"classification": "SQL Injection Detected" if is_vulnerable else "No SQL Injection Detected",
|
"result": "fail" if is_vulnerable else "pass",
|
||||||
"confidence": round(confidence, 2)
|
"confidence": round(confidence, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
return jsonify(result)
|
return jsonify(result)
|
||||||
|
|
||||||
# Run Flask server
|
# run the flask server
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
app.run(host="0.0.0.0", port=5000, debug=True)
|
app.run(host="0.0.0.0", port=5000, debug=False)
|
||||||
|
|
|
@ -3,7 +3,7 @@ name = "server-ml"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.11"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"ctransformers>=0.2.27",
|
"ctransformers>=0.2.27",
|
||||||
"flask>=3.1.0",
|
"flask>=3.1.0",
|
||||||
|
|
|
@ -1,3 +1,48 @@
|
||||||
flask
|
accelerate==1.3.0
|
||||||
scikit-learn
|
blinker==1.9.0
|
||||||
joblib
|
certifi==2025.1.31
|
||||||
|
charset-normalizer==3.4.1
|
||||||
|
click==8.1.8
|
||||||
|
ctransformers==0.2.27
|
||||||
|
filelock==3.17.0
|
||||||
|
flask==3.1.0
|
||||||
|
fsspec==2025.2.0
|
||||||
|
huggingface-hub==0.28.1
|
||||||
|
idna==3.10
|
||||||
|
itsdangerous==2.2.0
|
||||||
|
jinja2==3.1.5
|
||||||
|
markupsafe==3.0.2
|
||||||
|
mpmath==1.3.0
|
||||||
|
networkx==3.4.2
|
||||||
|
numpy==2.2.2
|
||||||
|
nvidia-cublas-cu12==12.4.5.8
|
||||||
|
nvidia-cuda-cupti-cu12==12.4.127
|
||||||
|
nvidia-cuda-nvrtc-cu12==12.4.127
|
||||||
|
nvidia-cuda-runtime-cu12==12.4.127
|
||||||
|
nvidia-cudnn-cu12==9.1.0.70
|
||||||
|
nvidia-cufft-cu12==11.2.1.3
|
||||||
|
nvidia-curand-cu12==10.3.5.147
|
||||||
|
nvidia-cusolver-cu12==11.6.1.9
|
||||||
|
nvidia-cusparse-cu12==12.3.1.170
|
||||||
|
nvidia-cusparselt-cu12==0.6.2
|
||||||
|
nvidia-nccl-cu12==2.21.5
|
||||||
|
nvidia-nvjitlink-cu12==12.4.127
|
||||||
|
nvidia-nvtx-cu12==12.4.127
|
||||||
|
packaging==24.2
|
||||||
|
peft==0.14.0
|
||||||
|
psutil==6.1.1
|
||||||
|
py-cpuinfo==9.0.0
|
||||||
|
pyyaml==6.0.2
|
||||||
|
regex==2024.11.6
|
||||||
|
requests==2.32.3
|
||||||
|
safetensors==0.5.2
|
||||||
|
setuptools==75.8.0
|
||||||
|
sympy==1.13.1
|
||||||
|
tokenizers==0.21.0
|
||||||
|
torch==2.6.0
|
||||||
|
tqdm==4.67.1
|
||||||
|
transformers==4.48.2
|
||||||
|
triton==3.2.0
|
||||||
|
typing-extensions==4.12.2
|
||||||
|
urllib3==2.3.0
|
||||||
|
werkzeug==3.1.3
|
||||||
|
|
|
@ -1,46 +0,0 @@
|
||||||
# Test Cases
|
|
||||||
|
|
||||||
## 1. Basic test cases
|
|
||||||
|
|
||||||
[Good] `SELECT _ FROM users WHERE id = 1`
|
|
||||||
[Bad] `SELECT _ FROM users WHERE id = 1 OR 1=1`
|
|
||||||
|
|
||||||
## 2. Authentication bypass cases
|
|
||||||
|
|
||||||
[Bad] `SELECT _ FROM users WHERE username = 'admin' --`
|
|
||||||
[Bad] `SELECT _ FROM users WHERE username = 'admin' #`
|
|
||||||
[Bad] `SELECT \* FROM users WHERE username = 'admin' OR '1'='1'`
|
|
||||||
|
|
||||||
## 3. Union based injection cases
|
|
||||||
|
|
||||||
[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT null, 'hacker'`
|
|
||||||
[Bad] `SELECT id, username FROM users WHERE id = 1 UNION SELECT 1, 'hacked' FROM dual`
|
|
||||||
[Bad] `SELECT database() UNION SELECT 1`
|
|
||||||
|
|
||||||
## 4. Error based injection cases
|
|
||||||
|
|
||||||
[Bad] `SELECT _ FROM users WHERE id = 1 AND (SELECT 1 FROM users WHERE id=2)=1`
|
|
||||||
[Bad] `SELECT _ FROM users WHERE id = (SELECT COUNT(\*) FROM users)`
|
|
||||||
|
|
||||||
## 5. Blind SQL injection cases
|
|
||||||
|
|
||||||
[Bad] `SELECT _ FROM users WHERE id = 1; WAITFOR DELAY '00:00:10' --`
|
|
||||||
[Bad] `SELECT _ FROM users WHERE username = 'admin' AND 1=1`
|
|
||||||
|
|
||||||
## 6. Hex and Base64 encoded injection cases
|
|
||||||
|
|
||||||
[Bad] `SELECT _ FROM users WHERE username = 0x61646D696E`
|
|
||||||
[Bad] `SELECT _ FROM users WHERE username = 'YWRtaW4='`
|
|
||||||
|
|
||||||
## 7. False positives cases
|
|
||||||
|
|
||||||
[Good] `SELECT _ FROM users WHERE id = 5`
|
|
||||||
[Good] `SELECT users.name, orders.amount FROM users JOIN orders ON users.id = orders.user_id`
|
|
||||||
[Good] `SELECT _ FROM users WHERE username = ? AND password = ?`
|
|
||||||
|
|
||||||
## 8. Edge cases
|
|
||||||
|
|
||||||
[Good] `""`
|
|
||||||
[Bad] `'; --`
|
|
||||||
[Good] `12345`
|
|
||||||
[Good] `asdkjhasdkjh`
|
|
|
@ -1,35 +0,0 @@
|
||||||
import joblib
|
|
||||||
from sklearn.feature_extraction.text import CountVectorizer
|
|
||||||
from sklearn.linear_model import LogisticRegression
|
|
||||||
from sklearn.pipeline import make_pipeline
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
|
|
||||||
# random data
|
|
||||||
data = [
|
|
||||||
("' OR '1'='1", 1),
|
|
||||||
("SELECT * FROM users WHERE id=1", 1),
|
|
||||||
("DROP TABLE users;", 1),
|
|
||||||
("username=admin'--", 1),
|
|
||||||
("hello world", 0),
|
|
||||||
("this is a normal query", 0),
|
|
||||||
("select data from table", 0),
|
|
||||||
("just another harmless input", 0),
|
|
||||||
]
|
|
||||||
|
|
||||||
queries, labels = zip(*data)
|
|
||||||
|
|
||||||
# split data into training and testing sets
|
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
|
||||||
queries, labels, test_size=0.2, random_state=42
|
|
||||||
)
|
|
||||||
|
|
||||||
# build a pipeline with a vectorizer and a logistic regression model
|
|
||||||
pipeline = make_pipeline(CountVectorizer(), LogisticRegression())
|
|
||||||
|
|
||||||
# train the model
|
|
||||||
pipeline.fit(X_train, y_train)
|
|
||||||
|
|
||||||
# save the model to a file
|
|
||||||
joblib.dump(pipeline, "model.pkl")
|
|
||||||
|
|
||||||
print("Model trained and saved to model.pkl")
|
|
Loading…
Reference in a new issue