61 lines
1.9 KiB
Python
61 lines
1.9 KiB
Python
from flask import Flask, request, jsonify
|
|
import torch
|
|
from transformers import MobileBertTokenizer, MobileBertForSequenceClassification
|
|
|
|
# Announce startup before the model is loaded.
print("Starting server...")

# Flask application object that the route below is registered on.
app = Flask(__name__)

# Inference is pinned to the CPU. To use a GPU when one is available, use:
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Load the tokenizer and the sequence-classification model once, at import
# time, so every request reuses the same objects.
print("Loading model...")
tokenizer = MobileBertTokenizer.from_pretrained("google/mobilebert-uncased")
model = MobileBertForSequenceClassification.from_pretrained(
    "cssupport/mobilebert-sql-injection-detect"
)
# Move the weights to the chosen device and switch to evaluation mode.
model.to(device).eval()
print("Model loaded")
|
|
|
|
def predict(text):
    """Classify *text* with the module-level tokenizer and model.

    The text is tokenized (truncated to at most 512 tokens, no padding),
    run through the model with gradients disabled, and the logits are
    turned into probabilities with a softmax over the class dimension.

    Returns:
        A ``(predicted_class, confidence)`` tuple: the index of the most
        probable class and the probability assigned to that class.
    """
    encoded = tokenizer(
        text,
        padding=False,
        truncation=True,
        return_tensors="pt",
        max_length=512,
    )
    model_inputs = {
        "input_ids": encoded["input_ids"].to(device),
        "attention_mask": encoded["attention_mask"].to(device),
    }

    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        logits = model(**model_inputs).logits

    probabilities = torch.softmax(logits, dim=1)
    predicted_class = torch.argmax(probabilities, dim=1).item()

    # Index row 0: a single text yields a batch of one.
    return predicted_class, probabilities[0][predicted_class].item()
|
|
|
|
# Minimum confidence for a class-1 prediction to be reported as "fail".
_FAIL_CONFIDENCE_THRESHOLD = 0.7


@app.route("/predict", methods=["POST"])
def classify_query():
    """Handle ``POST /predict``: classify the ``query`` field of a JSON body.

    Expects a JSON object containing a string ``query``. Responds with a JSON
    object holding the original ``query``, a ``result`` of ``"fail"`` when the
    predicted class is 1 with confidence above the threshold (``"pass"``
    otherwise), and the ``confidence`` rounded to two decimals.

    Returns a 400 response with an ``error`` message when the body is not a
    JSON object, lacks ``query``, or ``query`` is not a string.
    """
    # silent=True yields None instead of raising when the body is missing,
    # malformed, or not sent as JSON, so bad requests get a clean 400 below
    # rather than an unhandled exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "query" not in data:
        return jsonify({"error": "Missing 'query' in request"}), 400

    query = data["query"]
    # Reject non-text payloads (numbers, null, lists, objects) before they
    # reach the tokenizer.
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400

    predicted_class, confidence = predict(query)

    # Only a sufficiently confident class-1 prediction is marked as bad.
    is_vulnerable = (
        predicted_class == 1 and confidence > _FAIL_CONFIDENCE_THRESHOLD
    )

    return jsonify(
        {
            "query": query,
            "result": "fail" if is_vulnerable else "pass",
            "confidence": round(confidence, 2),
        }
    )
|
|
|
|
# Start Flask's built-in server when this file is executed as a script.
# Host 0.0.0.0 makes it listen on all network interfaces, on port 5000,
# with debug mode off.
if __name__ == "__main__":
    app.run(debug=False, port=5000, host="0.0.0.0")
|