Initial commit: Unity WordConnect project
This commit is contained in:
908
Assets/WordConnectGameToolkit/Scripts/NLP/ModelController.cs
Normal file
908
Assets/WordConnectGameToolkit/Scripts/NLP/ModelController.cs
Normal file
@ -0,0 +1,908 @@
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using Unity.Sentis;
|
||||
using UnityEngine;
|
||||
using System.Text;
|
||||
using WordsToolkit.Scripts.Levels;
|
||||
using WordsToolkit.Scripts.Services;
|
||||
using WordsToolkit.Scripts.Services.BannedWords;
|
||||
using VContainer;
|
||||
using Random = System.Random;
|
||||
|
||||
namespace WordsToolkit.Scripts.NLP
|
||||
{
|
||||
public class ModelController : IModelController
|
||||
{
|
||||
private ILanguageService languageService;
|
||||
private LanguageConfiguration languageConfiguration;
|
||||
|
||||
[SerializeField]
|
||||
Dictionary<string, ModelAsset> languageModels = new Dictionary<string, ModelAsset>();
|
||||
private IBannedWordsService bannedWordsService;
|
||||
|
||||
private Dictionary<string, Worker> m_Workers = new Dictionary<string, Worker>();
|
||||
private Dictionary<string, Dictionary<string, int>> wordToIndexByLanguage = new Dictionary<string, Dictionary<string, int>>();
|
||||
private Dictionary<string, int> m_VectorDimensionByLanguage = new Dictionary<string, int>();
|
||||
private Dictionary<string, int> m_VocabSizeByLanguage = new Dictionary<string, int>();
|
||||
private Dictionary<string, Tensor<int>> m_InputsByLanguage = new Dictionary<string, Tensor<int>>();
|
||||
|
||||
private string m_DefaultLanguage = "en";
|
||||
public int VectorDimension = 100;
|
||||
|
||||
// Protection flag to prevent accidental binary overwrite when you have custom words
|
||||
// NOTE: This is now mainly for the old SaveModelBinary method - new architecture uses custom words files
|
||||
private bool protectBinaryFile = false;
|
||||
|
||||
public bool IsModelLoaded(string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
return m_Workers.ContainsKey(language) &&
|
||||
wordToIndexByLanguage.ContainsKey(language) &&
|
||||
wordToIndexByLanguage[language].Count > 0;
|
||||
}
|
||||
|
||||
public ModelController(IBannedWordsService bannedWordsService,
|
||||
LanguageConfiguration languageConfiguration,
|
||||
ILanguageService languageService = null)
|
||||
{
|
||||
this.languageService = languageService;
|
||||
this.languageConfiguration = languageConfiguration;
|
||||
this.bannedWordsService = bannedWordsService;
|
||||
InitializeFromConfiguration();
|
||||
LoadModels();
|
||||
}
|
||||
|
||||
public IEnumerable<string> AvailableLanguages => m_Workers.Keys;
|
||||
|
||||
private void InitializeFromConfiguration()
|
||||
{
|
||||
// Use LanguageService to get current language if available, fallback to configuration default
|
||||
if (languageService != null)
|
||||
{
|
||||
m_DefaultLanguage = languageService.GetCurrentLanguageCode();
|
||||
}
|
||||
else
|
||||
{
|
||||
m_DefaultLanguage = languageConfiguration?.defaultLanguage ?? "en";
|
||||
}
|
||||
|
||||
languageModels.Clear();
|
||||
|
||||
foreach (var langInfo in languageConfiguration.languages)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(langInfo.code) && langInfo.languageModel != null)
|
||||
{
|
||||
languageModels[langInfo.code] = langInfo.languageModel;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void LoadModels()
|
||||
{
|
||||
foreach (var languagePair in languageModels)
|
||||
{
|
||||
LoadModelBin(languagePair.Key, languagePair.Value);
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Sets whether to protect existing binary files from being overwritten.
|
||||
/// When true, LoadModel() won't overwrite existing .bin files that might contain custom words.
|
||||
/// </summary>
|
||||
public void SetBinaryFileProtection(bool protect)
|
||||
{
|
||||
protectBinaryFile = protect;
|
||||
Debug.Log($"[ModelController] Binary file protection {(protect ? "ENABLED" : "DISABLED")}");
|
||||
}
|
||||
|
||||
public void LoadModelBin(string language, ModelAsset modelAsset)
|
||||
{
|
||||
// Always load the model structure first
|
||||
var model = ModelLoader.Load(modelAsset);
|
||||
|
||||
// (Re)create worker + input tensor
|
||||
if (m_Workers.ContainsKey(language))
|
||||
{
|
||||
m_Workers[language].Dispose();
|
||||
m_InputsByLanguage[language]?.Dispose();
|
||||
}
|
||||
|
||||
m_Workers[language] = new Worker(model, BackendType.CPU);
|
||||
m_InputsByLanguage[language] = new Tensor<int>(new TensorShape(1));
|
||||
|
||||
// Always load base vocabulary from ONNX JSON (no caching)
|
||||
var worker = m_Workers[language];
|
||||
var dummyInput = new Tensor<int>(new TensorShape(1));
|
||||
dummyInput[0] = 0;
|
||||
|
||||
try
|
||||
{
|
||||
worker.Schedule(dummyInput);
|
||||
var vocabJsonTensor = worker.PeekOutput("wc_vocab_json");
|
||||
|
||||
if (vocabJsonTensor != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
var byteTensor = vocabJsonTensor as Tensor<byte>;
|
||||
if (byteTensor != null)
|
||||
{
|
||||
var jsonData = new byte[byteTensor.shape.length];
|
||||
var cpuTensor = byteTensor.ReadbackAndClone();
|
||||
try
|
||||
{
|
||||
for (int i = 0; i < jsonData.Length; i++)
|
||||
jsonData[i] = cpuTensor[i];
|
||||
|
||||
string jsonString = Encoding.UTF8.GetString(jsonData);
|
||||
ParseAndLoadVocabulary(language, jsonString);
|
||||
}
|
||||
finally
|
||||
{
|
||||
cpuTensor.Dispose();
|
||||
}
|
||||
}
|
||||
else if (vocabJsonTensor is Tensor<int> intTensor)
|
||||
{
|
||||
var cpuTensor = intTensor.ReadbackAndClone();
|
||||
try
|
||||
{
|
||||
var jsonBytes = new List<byte>();
|
||||
for (int i = 0; i < cpuTensor.shape.length; i++)
|
||||
{
|
||||
int value = cpuTensor[i];
|
||||
if (value == 0)
|
||||
break;
|
||||
if (value >= 0 && value <= 255)
|
||||
jsonBytes.Add((byte)value);
|
||||
}
|
||||
|
||||
if (jsonBytes.Count > 0)
|
||||
{
|
||||
string jsonString = Encoding.UTF8.GetString(jsonBytes.ToArray());
|
||||
ParseAndLoadVocabulary(language, jsonString);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
cpuTensor.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
vocabJsonTensor.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new Exception($"Error loading base vocabulary for {language}: {e.Message}", e);
|
||||
}
|
||||
finally
|
||||
{
|
||||
dummyInput.Dispose();
|
||||
}
|
||||
|
||||
// Now load any custom words from binary file
|
||||
LoadCustomWordsFromBinary(language);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Loads custom words from binary file and adds them to the existing vocabulary.
|
||||
/// Binary file contains ONLY custom words, not the entire model cache.
|
||||
/// </summary>
|
||||
private void LoadCustomWordsFromBinary(string language)
|
||||
{
|
||||
string path = Path.Combine(Application.dataPath, "WordsToolkit", "model",
|
||||
"custom", $"{language}_custom_words.bin");
|
||||
|
||||
if (!File.Exists(path))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (!wordToIndexByLanguage.ContainsKey(language))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
using var fs = new FileStream(path, FileMode.Open, FileAccess.Read);
|
||||
using var br = new BinaryReader(fs, Encoding.UTF8);
|
||||
|
||||
// Read header
|
||||
if (br.ReadInt32() != 0x43555354) // "CUST" magic number
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
int baseVocabSize = br.ReadInt32(); // Original vocab size when custom words were added
|
||||
int customWordCount = br.ReadInt32(); // Number of custom words
|
||||
int vectorDim = br.ReadInt32(); // Vector dimension
|
||||
|
||||
// Verify compatibility
|
||||
if (vectorDim != m_VectorDimensionByLanguage[language])
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// Load custom words
|
||||
int currentVocabSize = wordToIndexByLanguage[language].Count;
|
||||
int nextIndex = currentVocabSize;
|
||||
|
||||
for (int i = 0; i < customWordCount; i++)
|
||||
{
|
||||
string word = br.ReadString();
|
||||
|
||||
// Add to vocabulary
|
||||
wordToIndexByLanguage[language][word] = nextIndex++;
|
||||
}
|
||||
|
||||
// Update vocabulary size
|
||||
m_VocabSizeByLanguage[language] = nextIndex;
|
||||
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Debug.LogError($"[ModelController] Error loading custom words for '{language}': {e.Message}");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Saves only the custom words (not base vocabulary) to binary file.
|
||||
/// This creates a lightweight file with just the added words.
|
||||
/// </summary>
|
||||
private void SaveCustomWordsToBinary(string language)
|
||||
{
|
||||
if (!wordToIndexByLanguage.ContainsKey(language))
|
||||
{
|
||||
Debug.LogError($"[ModelController] No vocabulary loaded for '{language}' - cannot save custom words");
|
||||
return;
|
||||
}
|
||||
|
||||
string dir = Path.Combine(Application.dataPath, "WordsToolkit", "model", "custom");
|
||||
string path = Path.Combine(dir, $"{language}_custom_words.bin");
|
||||
Directory.CreateDirectory(dir);
|
||||
|
||||
try
|
||||
{
|
||||
// Get all custom words (assuming they have higher indices than base vocabulary)
|
||||
var vocab = wordToIndexByLanguage[language];
|
||||
var baseVocabSize = GetEstimatedBaseVocabSize(language);
|
||||
var customWords = vocab.Where(kvp => kvp.Value >= baseVocabSize)
|
||||
.OrderBy(kvp => kvp.Value)
|
||||
.ToList();
|
||||
|
||||
if (customWords.Count == 0)
|
||||
{
|
||||
Debug.Log($"[ModelController] No custom words to save for '{language}'");
|
||||
return;
|
||||
}
|
||||
|
||||
using var fs = new FileStream(path, FileMode.Create, FileAccess.Write);
|
||||
using var bw = new BinaryWriter(fs, Encoding.UTF8);
|
||||
|
||||
// Write header
|
||||
bw.Write(0x43555354); // "CUST" magic number
|
||||
bw.Write(baseVocabSize); // Original vocab size when custom words were added
|
||||
bw.Write(customWords.Count); // Number of custom words
|
||||
bw.Write(m_VectorDimensionByLanguage[language]); // Vector dimension
|
||||
|
||||
// Write custom words (just the words, not embeddings since we can't access them easily)
|
||||
foreach (var kvp in customWords)
|
||||
{
|
||||
bw.Write(kvp.Key); // Word string
|
||||
}
|
||||
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Debug.LogError($"[ModelController] Error saving custom words for '{language}': {e.Message}");
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Estimates the base vocabulary size by looking at the original model.
|
||||
/// This is a heuristic - in a real implementation you might want to store this value.
|
||||
/// </summary>
|
||||
private int GetEstimatedBaseVocabSize(string language)
|
||||
{
|
||||
// Try to load the original model and get its vocabulary size
|
||||
if (languageModels.TryGetValue(language, out var modelAsset) && modelAsset != null)
|
||||
{
|
||||
try
|
||||
{
|
||||
var model = ModelLoader.Load(modelAsset);
|
||||
using var worker = new Worker(model, BackendType.CPU);
|
||||
using var dummyInput = new Tensor<int>(new TensorShape(1));
|
||||
dummyInput[0] = 0;
|
||||
|
||||
worker.Schedule(dummyInput);
|
||||
var vocabJsonTensor = worker.PeekOutput("wc_vocab_json");
|
||||
|
||||
if (vocabJsonTensor != null)
|
||||
{
|
||||
// Parse JSON to get original vocab size
|
||||
// This is simplified - you'd need to extract and parse the JSON
|
||||
vocabJsonTensor.Dispose();
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Debug.LogWarning($"[ModelController] Could not determine base vocab size for '{language}': {e.Message}");
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback heuristic - assume custom words start after a reasonable base size
|
||||
var currentSize = wordToIndexByLanguage[language].Count;
|
||||
return currentSize > 1000 ? currentSize - 100 : Math.Max(1, currentSize / 2);
|
||||
}
|
||||
|
||||
private void ParseAndLoadVocabulary(string language, string jsonString)
|
||||
{
|
||||
try
|
||||
{
|
||||
var jsonObject = JObject.Parse(jsonString);
|
||||
var wordToIndex = new Dictionary<string, int>();
|
||||
var wordIndexDict = jsonObject["word_to_index"].ToObject<Dictionary<string, int>>();
|
||||
|
||||
foreach (var pair in wordIndexDict)
|
||||
{
|
||||
wordToIndex[pair.Key] = pair.Value;
|
||||
}
|
||||
|
||||
|
||||
int vectorDimension = jsonObject["vector_size"].Value<int>();
|
||||
int vocabSize = jsonObject["vocab_size"].Value<int>();
|
||||
|
||||
wordToIndexByLanguage[language] = wordToIndex;
|
||||
m_VectorDimensionByLanguage[language] = vectorDimension;
|
||||
m_VocabSizeByLanguage[language] = vocabSize;
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new Exception($"Error parsing vocabulary JSON for {language}: {e.Message}", e);
|
||||
}
|
||||
}
|
||||
|
||||
public float[] GetWordVector(string word, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (!wordToIndexByLanguage[language].ContainsKey(word))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
m_InputsByLanguage[language][0] = wordToIndexByLanguage[language][word];
|
||||
m_Workers[language].Schedule(m_InputsByLanguage[language]);
|
||||
|
||||
var outputTensor = m_Workers[language].PeekOutput() as Tensor<float>;
|
||||
if (outputTensor == null)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
var cpuTensor = outputTensor.ReadbackAndClone();
|
||||
try
|
||||
{
|
||||
float[] result = new float[m_VectorDimensionByLanguage[language]];
|
||||
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
|
||||
{
|
||||
result[i] = cpuTensor[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
finally
|
||||
{
|
||||
cpuTensor.Dispose();
|
||||
outputTensor.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
public bool IsWordKnown(string word, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
if (bannedWordsService.IsWordBanned(word, language))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
float[] vector = GetWordVector(word, language);
|
||||
if (vector == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private bool IsZeroVector(float[] vector)
|
||||
{
|
||||
foreach (float value in vector)
|
||||
{
|
||||
if (!Mathf.Approximately(value, 0f))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
public float GetCosineSimilarity(string word1, string word2, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return -1f;
|
||||
}
|
||||
|
||||
float[] vector1 = GetWordVector(word1, language);
|
||||
float[] vector2 = GetWordVector(word2, language);
|
||||
|
||||
if (vector1 == null || vector2 == null)
|
||||
return -1f;
|
||||
|
||||
return CosineSimilarity(vector1, vector2);
|
||||
}
|
||||
|
||||
public bool AddWordAndSave(string newWord, string language = null)
|
||||
{
|
||||
language ??= languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage;
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
Debug.LogWarning($"[ModelController] AddWord failed – model for '{language}' not loaded.");
|
||||
return false;
|
||||
}
|
||||
|
||||
if (wordToIndexByLanguage[language].ContainsKey(newWord))
|
||||
{
|
||||
Debug.LogWarning($"[ModelController] Word '{newWord}' already exists in vocab.");
|
||||
return false;
|
||||
}
|
||||
|
||||
var newVector = new float[m_VectorDimensionByLanguage[language]];
|
||||
int dim = m_VectorDimensionByLanguage[language];
|
||||
if (newVector == null || newVector.Length != dim)
|
||||
{
|
||||
Debug.LogWarning($"[ModelController] Vector length mismatch (expected {dim}).");
|
||||
return false;
|
||||
}
|
||||
|
||||
// 1️⃣ Update dictionaries
|
||||
int newIndex = m_VocabSizeByLanguage[language];
|
||||
wordToIndexByLanguage[language][newWord] = newIndex;
|
||||
m_VocabSizeByLanguage[language] = newIndex + 1;
|
||||
|
||||
// 2️⃣ Load the Model anew so we can mutate its constants safely
|
||||
if (!languageModels.TryGetValue(language, out var modelAsset) || modelAsset == null)
|
||||
{
|
||||
Debug.LogError($"[ModelController] Missing ModelAsset for language '{language}'.");
|
||||
return false;
|
||||
}
|
||||
|
||||
var model = ModelLoader.Load(modelAsset);
|
||||
|
||||
// We assume the <b>first</b> constant is the [vocab, dim] embedding matrix.
|
||||
// If your graph changes, adjust the index or name match here.
|
||||
if (model.constants == null || model.constants.Count == 0)
|
||||
{
|
||||
Debug.LogError("[ModelController] Model has no constants — cannot extend.");
|
||||
return false;
|
||||
}
|
||||
|
||||
var embConst = model.constants[0];
|
||||
if (embConst.dataType != DataType.Float || embConst.shape.rank != 2 || embConst.shape[1] != dim)
|
||||
{
|
||||
Debug.LogError("[ModelController] Unexpected embedding tensor layout.");
|
||||
return false;
|
||||
}
|
||||
|
||||
int oldVocab = embConst.shape[0];
|
||||
int oldElems = embConst.shape.length; // vocab * dim
|
||||
float[] oldBuf = new float[oldElems];
|
||||
NativeTensorArray.Copy(embConst.weights, 0, oldBuf, 0, oldElems);
|
||||
|
||||
// Build new buffer = old + newVector
|
||||
var newBuf = new float[oldElems + dim];
|
||||
Buffer.BlockCopy(oldBuf, 0, newBuf, 0, oldElems * sizeof(float));
|
||||
Buffer.BlockCopy(newVector,0, newBuf, oldElems * sizeof(float), dim * sizeof(float));
|
||||
|
||||
// Sentis requires a non‑generic NativeTensorArrayFromManagedArray
|
||||
// Sentis requires (Array, bytesPerElem, length, channels)
|
||||
// ctor args: (Array data, int srcElementOffset, int srcElementSize, int numDestElement)
|
||||
var newWeights = new NativeTensorArrayFromManagedArray(
|
||||
newBuf, // managed float[]
|
||||
0, // start at element 0
|
||||
sizeof(float),
|
||||
newBuf.Length); // total elements
|
||||
#pragma warning disable 618 // Constant.weights setter is obsolete but still functional
|
||||
embConst.weights = newWeights;
|
||||
#pragma warning restore 618
|
||||
embConst.shape = new TensorShape(oldVocab + 1, dim); // update shape metadata
|
||||
|
||||
// 3️⃣ Write new BIN
|
||||
SaveCustomWordsToBinary(language);
|
||||
|
||||
Debug.Log($"[ModelController] Successfully added word '{newWord}' to {language} vocabulary at index {newIndex}");
|
||||
return true;
|
||||
}
|
||||
|
||||
private float CosineSimilarity(float[] v1, float[] v2)
|
||||
{
|
||||
float dotProduct = 0;
|
||||
float norm1 = 0;
|
||||
float norm2 = 0;
|
||||
|
||||
for (int i = 0; i < v1.Length; i++)
|
||||
{
|
||||
dotProduct += v1[i] * v2[i];
|
||||
norm1 += v1[i] * v1[i];
|
||||
norm2 += v2[i] * v2[i];
|
||||
}
|
||||
|
||||
return dotProduct / (Mathf.Sqrt(norm1) * Mathf.Sqrt(norm2));
|
||||
}
|
||||
|
||||
private void OnDisable()
|
||||
{
|
||||
foreach (var worker in m_Workers.Values)
|
||||
{
|
||||
worker?.Dispose();
|
||||
}
|
||||
|
||||
foreach (var input in m_InputsByLanguage.Values)
|
||||
{
|
||||
input?.Dispose();
|
||||
}
|
||||
|
||||
m_Workers.Clear();
|
||||
m_InputsByLanguage.Clear();
|
||||
wordToIndexByLanguage.Clear();
|
||||
m_VectorDimensionByLanguage.Clear();
|
||||
m_VocabSizeByLanguage.Clear();
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
OnDisable();
|
||||
}
|
||||
|
||||
public List<string> GetRandomWords(int wordCount, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
var result = new List<string>();
|
||||
var words = new List<string>(wordToIndexByLanguage[language].Keys)
|
||||
.Where(word => !bannedWordsService.IsWordBanned(word, language))
|
||||
.ToList();
|
||||
var random = new Random();
|
||||
|
||||
wordCount = Mathf.Min(wordCount, words.Count);
|
||||
|
||||
while (result.Count < wordCount)
|
||||
{
|
||||
int randomIndex = random.Next(words.Count);
|
||||
string word = words[randomIndex];
|
||||
|
||||
if (!result.Contains(word))
|
||||
result.Add(word);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public List<string> GetAllWords(string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
return new List<string>(wordToIndexByLanguage[language].Keys)
|
||||
.Where(word => !bannedWordsService.IsWordBanned(word, language))
|
||||
.ToList();
|
||||
}
|
||||
|
||||
public List<string> GetWordsFromSymbols(string inputSymbols, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if(!IsModelLoaded(language))
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(inputSymbols))
|
||||
return new List<string>();
|
||||
|
||||
inputSymbols = inputSymbols.ToLower();
|
||||
Dictionary<char, int> charCounts = new Dictionary<char, int>();
|
||||
foreach (char c in inputSymbols)
|
||||
{
|
||||
if (charCounts.ContainsKey(c))
|
||||
charCounts[c]++;
|
||||
else
|
||||
charCounts[c] = 1;
|
||||
}
|
||||
|
||||
var candidateWords = new List<string>();
|
||||
foreach (var word in wordToIndexByLanguage[language].Keys)
|
||||
{
|
||||
if (bannedWordsService != null && bannedWordsService.IsWordBanned(word, language))
|
||||
continue;
|
||||
|
||||
Dictionary<char, int> remainingCounts = new Dictionary<char, int>(charCounts);
|
||||
bool isValid = true;
|
||||
|
||||
foreach (char c in word)
|
||||
{
|
||||
if (!remainingCounts.ContainsKey(c) || remainingCounts[c] <= 0)
|
||||
{
|
||||
isValid = false;
|
||||
break;
|
||||
}
|
||||
remainingCounts[c]--;
|
||||
}
|
||||
|
||||
if (isValid)
|
||||
{
|
||||
candidateWords.Add(word);
|
||||
}
|
||||
}
|
||||
|
||||
if(!candidateWords.Contains(inputSymbols) &&
|
||||
!bannedWordsService.IsWordBanned(inputSymbols, language) && IsWordKnown(inputSymbols, language))
|
||||
{
|
||||
candidateWords.Add(inputSymbols);
|
||||
}
|
||||
|
||||
float[] referenceVector = CreateInputTensor(inputSymbols, language);
|
||||
if (referenceVector != null && candidateWords.Count > 0)
|
||||
{
|
||||
return RankWordsBySimilarity(candidateWords, referenceVector, 200, language);
|
||||
}
|
||||
|
||||
return candidateWords
|
||||
.OrderByDescending(w => w.Length)
|
||||
.ToList();
|
||||
}
|
||||
|
||||
public float[] CreateInputTensor(string inputSymbols, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
if (string.IsNullOrEmpty(inputSymbols))
|
||||
return null;
|
||||
|
||||
var symbolSet = new HashSet<char>(inputSymbols.ToLower());
|
||||
|
||||
var bestMatches = wordToIndexByLanguage[language].Keys
|
||||
.Select(word => new {
|
||||
Word = word,
|
||||
SharedChars = word.Count(c => symbolSet.Contains(c))
|
||||
})
|
||||
.OrderByDescending(x => (float)x.SharedChars / x.Word.Length)
|
||||
.Take(5)
|
||||
.ToList();
|
||||
|
||||
if (bestMatches.Count == 0)
|
||||
return null;
|
||||
|
||||
float[] compositeVector = new float[m_VectorDimensionByLanguage[language]];
|
||||
foreach (var match in bestMatches)
|
||||
{
|
||||
float[] wordVector = GetWordVector(match.Word, language);
|
||||
if (wordVector != null)
|
||||
{
|
||||
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
|
||||
{
|
||||
compositeVector[i] += wordVector[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float magnitude = 0;
|
||||
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
|
||||
{
|
||||
magnitude += compositeVector[i] * compositeVector[i];
|
||||
}
|
||||
|
||||
magnitude = Mathf.Sqrt(magnitude);
|
||||
if (magnitude > 0)
|
||||
{
|
||||
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
|
||||
{
|
||||
compositeVector[i] /= magnitude;
|
||||
}
|
||||
}
|
||||
|
||||
return compositeVector;
|
||||
}
|
||||
|
||||
private List<string> RankWordsBySimilarity(List<string> candidateWords, float[] referenceVector, int maxResults, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
var scoredWords = new List<(string word, float score)>();
|
||||
|
||||
foreach (var word in candidateWords)
|
||||
{
|
||||
float[] wordVector = GetWordVector(word, language);
|
||||
if (wordVector != null)
|
||||
{
|
||||
float similarity = CosineSimilarity(referenceVector, wordVector);
|
||||
float adjustedScore = similarity * (0.8f + 0.2f * Mathf.Min(1f, word.Length / 10f));
|
||||
scoredWords.Add((word, adjustedScore));
|
||||
}
|
||||
}
|
||||
|
||||
return scoredWords
|
||||
.OrderByDescending(pair => pair.score)
|
||||
.Take(maxResults)
|
||||
.Select(pair => pair.word)
|
||||
.ToList();
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets a list of words that are exactly the specified length.
|
||||
/// Returns complete words only, does not truncate or modify the words.
|
||||
/// </summary>
|
||||
/// <param name="length">The exact length of words to return</param>
|
||||
/// <param name="language">Language code, defaults to current language if not specified</param>
|
||||
/// <returns>List of complete words of the specified length</returns>
|
||||
public List<string> GetWordsWithLength(int length, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
if (!IsModelLoaded(language))
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
if (length <= 0)
|
||||
{
|
||||
Debug.LogWarning($"Invalid word length requested: {length}. Length must be greater than 0.");
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
var random = new Random();
|
||||
var matchingWords = wordToIndexByLanguage[language].Keys
|
||||
.Where(word => word != null && word.Length == length)
|
||||
.Where(word => bannedWordsService == null || !bannedWordsService.IsWordBanned(word, language))
|
||||
.ToList();
|
||||
|
||||
if (matchingWords.Count == 0)
|
||||
{
|
||||
Debug.LogWarning($"No words found with exact length: {length}");
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
var sampleWords = matchingWords
|
||||
.OrderBy(word => random.Next())
|
||||
.ToList();
|
||||
|
||||
if (sampleWords.Count == 0)
|
||||
{
|
||||
return new List<string>();
|
||||
}
|
||||
|
||||
return sampleWords;
|
||||
}
|
||||
|
||||
private List<string> FindSimilarWords(float[] targetVector, int maxResults, string language = null)
|
||||
{
|
||||
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
|
||||
|
||||
var scoredWords = new List<(string word, float score)>();
|
||||
|
||||
foreach (var word in wordToIndexByLanguage[language].Keys)
|
||||
{
|
||||
if (bannedWordsService.IsWordBanned(word, language))
|
||||
continue;
|
||||
|
||||
float[] wordVector = GetWordVector(word, language);
|
||||
if (wordVector != null)
|
||||
{
|
||||
float similarity = CosineSimilarity(targetVector, wordVector);
|
||||
scoredWords.Add((word, similarity));
|
||||
}
|
||||
}
|
||||
|
||||
return scoredWords
|
||||
.OrderByDescending(pair => pair.score)
|
||||
.Take(maxResults)
|
||||
.Select(pair => pair.word)
|
||||
.ToList();
|
||||
}
|
||||
|
||||
public IEnumerable<string> GetAvailableLanguages()
|
||||
{
|
||||
return AvailableLanguages;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Clears custom words cache for a specific language or all languages.
|
||||
/// This only removes custom word files, not the base model data.
|
||||
/// </summary>
|
||||
/// <param name="language">Language to clear, or null to clear all</param>
|
||||
public void ClearCustomWordsCache(string language = null)
|
||||
{
|
||||
string customDir = Path.Combine(Application.dataPath, "WordsToolkit", "model", "custom");
|
||||
|
||||
if (!Directory.Exists(customDir))
|
||||
return;
|
||||
|
||||
try
|
||||
{
|
||||
if (language != null)
|
||||
{
|
||||
string customPath = Path.Combine(customDir, $"{language}_custom_words.bin");
|
||||
if (File.Exists(customPath))
|
||||
{
|
||||
File.Delete(customPath);
|
||||
Debug.Log($"[ModelController] Cleared custom words cache for language: {language}");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var customFiles = Directory.GetFiles(customDir, "*_custom_words.bin");
|
||||
foreach (var file in customFiles)
|
||||
{
|
||||
File.Delete(file);
|
||||
}
|
||||
Debug.Log($"[ModelController] Cleared all custom words cache files ({customFiles.Length} files)");
|
||||
}
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
Debug.LogError($"[ModelController] Error clearing custom words cache: {e.Message}");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public interface IModelController : IDisposable
|
||||
{
|
||||
bool IsModelLoaded(string language = null);
|
||||
void LoadModels();
|
||||
float[] GetWordVector(string word, string language = null);
|
||||
bool IsWordKnown(string word, string language = null);
|
||||
float GetCosineSimilarity(string word1, string word2, string language = null);
|
||||
List<string> GetRandomWords(int wordCount, string language = null);
|
||||
List<string> GetAllWords(string language = null);
|
||||
List<string> GetWordsFromSymbols(string inputSymbols, string language = null);
|
||||
List<string> GetWordsWithLength(int length, string language = null);
|
||||
IEnumerable<string> GetAvailableLanguages();
|
||||
float[] CreateInputTensor(string input, string language);
|
||||
bool AddWordAndSave(string newWord, string language = null);
|
||||
void ClearCustomWordsCache(string language = null);
|
||||
void SetBinaryFileProtection(bool protect);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user