# cspj-application/server-ml/main.py
# Last modified: 2025-02-06 01:28:40 +08:00
# 1=sql injection query and 0=normal sql query
import os

from unsloth import FastLanguageModel
from transformers import AutoTokenizer
# Load the fine-tuned classifier model and its tokenizer once, at import time.
model_name = "shukdevdatta123/sql_injection_classifier_DeepSeek_R1_fine_tuned_model"
# Read the Hugging Face access token from the environment instead of
# hard-coding it: the previous literal was a placeholder (an invalid token),
# and real tokens must never be committed to source control. When HF_TOKEN
# is unset this is None, i.e. no explicit token is passed.
hf_token = os.environ.get("HF_TOKEN")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    load_in_4bit=True,  # load 4-bit quantized weights to reduce memory use
    token=hf_token,
)
def predict_sql_injection(query, max_new_tokens=1000):
    """Ask the fine-tuned model to classify *query* and return its answer.

    The prompt asks the model to label the query as normal (``0``) or an
    injection attack (``1``). The text the model generates after the prompt
    is returned stripped of surrounding whitespace; it is not validated, so
    callers should not assume it is exactly ``"0"`` or ``"1"``.

    Args:
        query: SQL text to classify. It is inserted verbatim into the prompt.
        max_new_tokens: Upper bound on the number of generated tokens
            (default 1000, matching the previous hard-coded value).

    Returns:
        The decoded model output generated after the prompt, stripped.
    """
    # Put the model into unsloth's inference mode; the call returns the model.
    inference_model = FastLanguageModel.for_inference(model)
    # NOTE(review): *query* is untrusted and embedded verbatim, so a crafted
    # query can attempt prompt injection (e.g. by supplying its own
    # "### Classification:" answer). Treat the result as a heuristic signal.
    prompt = f"### Instruction:\nClassify the following SQL query as normal (0) or an injection attack (1).\n\n### Query:\n{query}\n\n### Classification:\n"
    # Move the inputs to whatever device the model lives on rather than
    # assuming a "cuda" device is present.
    inputs = tokenizer(prompt, return_tensors="pt").to(inference_model.device)
    outputs = inference_model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        use_cache=True,
    )
    # generate() returns prompt + continuation; decode only the new tokens.
    # Splitting the full decoded text on the prompt marker picks the wrong
    # segment whenever the model's own continuation repeats that marker.
    prompt_length = inputs.input_ids.shape[-1]
    generated_ids = outputs[0, prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Example usage. Guarded so that importing this module (e.g. from a server
# process) does not run an extra model inference and print at import time.
if __name__ == "__main__":
    test_query = "SELECT * FROM users WHERE id = '1' OR '1'='1' --"
    result = predict_sql_injection(test_query)
    print(f"Query: {test_query}\nPrediction: {result}")