using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Newtonsoft.Json.Linq;
using Unity.InferenceEngine;
using UnityEngine;
using System.Text;
using System.Globalization;
using WordsToolkit.Scripts.Levels;
using WordsToolkit.Scripts.Services;
using WordsToolkit.Scripts.Services.BannedWords;
using VContainer;
using Random = System.Random;

namespace WordsToolkit.Scripts.NLP
{
    /// <summary>
    /// Holds one inference worker and one vocabulary (word -> embedding row index) per language
    /// and answers word-vector, similarity and word-list queries against them.
    /// Custom words added at runtime are persisted in a small per-language binary file.
    /// </summary>
    public class ModelController : IModelController
    {
        /// <summary>Header magic of the custom words file ("CUST").</summary>
        private const int CustomWordsMagic = 0x43555354;

        private ILanguageService languageService;
        private LanguageConfiguration languageConfiguration;
        [SerializeField] Dictionary<string, ModelAsset> languageModels = new Dictionary<string, ModelAsset>();
        private IBannedWordsService bannedWordsService;
        private Dictionary<string, Worker> m_Workers = new Dictionary<string, Worker>();
        private Dictionary<string, Dictionary<string, int>> wordToIndexByLanguage = new Dictionary<string, Dictionary<string, int>>();
        private Dictionary<string, int> m_VectorDimensionByLanguage = new Dictionary<string, int>();

        // Next free vocabulary index per language (grows as custom words are added).
        private Dictionary<string, int> m_VocabSizeByLanguage = new Dictionary<string, int>();

        // First index that does NOT belong to the model's own vocabulary. Recorded when the
        // vocabulary JSON is parsed so that custom words can be told apart from base words.
        private Dictionary<string, int> m_BaseVocabSizeByLanguage = new Dictionary<string, int>();

        private Dictionary<string, Tensor<int>> m_InputsByLanguage = new Dictionary<string, Tensor<int>>();
        private string m_DefaultLanguage = "en";
        public int VectorDimension = 100;

        // Protection flag to prevent accidental binary overwrite when you have custom words.
        // NOTE: This is now mainly for the old SaveModelBinary method - new architecture uses custom words files.
        private bool protectBinaryFile = false;

        // One generator for the whole controller; creating a new one per call can repeat sequences.
        private readonly Random m_Random = new Random();

        public ModelController(IBannedWordsService bannedWordsService, LanguageConfiguration languageConfiguration, ILanguageService languageService = null)
        {
            this.languageService = languageService;
            this.languageConfiguration = languageConfiguration;
            this.bannedWordsService = bannedWordsService;

            // LoadModels() re-reads the configuration itself.
            LoadModels();
        }

        /// <summary>Language codes for which a worker has been created.</summary>
        public IEnumerable<string> AvailableLanguages => m_Workers.Keys;

        /// <summary>
        /// Resolves the effective language: the explicit argument, else the language service's
        /// current language, else the default language.
        /// </summary>
        private string ResolveLanguage(string language)
        {
            return language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
        }

        /// <summary>
        /// Returns true when a banned-words service is present and reports the word as banned.
        /// </summary>
        private bool IsBanned(string word, string language)
        {
            return bannedWordsService != null && bannedWordsService.IsWordBanned(word, language);
        }

        /// <summary>
        /// Normalizes text by removing diacritics, accents, and converting to lowercase.
        /// This allows word matching to ignore emphasis marks.
        /// </summary>
        /// <param name="text">Text to normalize; null or empty is returned unchanged.</param>
        /// <returns>Lower-cased text without non-spacing marks, in normalization form C.</returns>
        private string NormalizeText(string text)
        {
            if (string.IsNullOrEmpty(text))
            {
                return text;
            }

            // Invariant casing keeps vocabulary keys independent of the device culture
            // (a Turkish culture would otherwise lower-case 'I' to a dotless 'ı').
            string decomposed = text.ToLowerInvariant().Normalize(NormalizationForm.FormD);
            var builder = new StringBuilder(decomposed.Length);

            foreach (char c in decomposed)
            {
                if (CharUnicodeInfo.GetUnicodeCategory(c) != UnicodeCategory.NonSpacingMark)
                {
                    builder.Append(c);
                }
            }

            return builder.ToString().Normalize(NormalizationForm.FormC);
        }

        /// <summary>
        /// Returns true when the language has both a worker and a non-empty vocabulary.
        /// </summary>
        public bool IsModelLoaded(string language = null)
        {
            language = ResolveLanguage(language);
            return m_Workers.ContainsKey(language)
                   && wordToIndexByLanguage.TryGetValue(language, out var vocab)
                   && vocab.Count > 0;
        }

        /// <summary>
        /// Reads the default language and the language -> model map from the configuration.
        /// </summary>
        private void InitializeFromConfiguration()
        {
            // Use LanguageService to get current language if available, fallback to configuration default.
            string configuredDefault = languageConfiguration != null ? languageConfiguration.defaultLanguage : null;
            m_DefaultLanguage = languageService?.GetCurrentLanguageCode() ?? configuredDefault ?? "en";

            languageModels.Clear();
            if (languageConfiguration == null || languageConfiguration.languages == null)
            {
                return;
            }

            foreach (var langInfo in languageConfiguration.languages)
            {
                if (!string.IsNullOrEmpty(langInfo.code) && langInfo.languageModel != null)
                {
                    languageModels[langInfo.code] = langInfo.languageModel;
                }
            }
        }

        /// <summary>
        /// Re-reads the configuration and (re)loads the model of every configured language.
        /// </summary>
        public void LoadModels()
        {
            InitializeFromConfiguration();
            foreach (var languagePair in languageModels)
            {
                LoadModelBin(languagePair.Key, languagePair.Value);
            }
        }

        /// <summary>
        /// Sets whether to protect existing binary files from being overwritten.
        /// When true, LoadModel() won't overwrite existing .bin files that might contain custom words.
        /// </summary>
        public void SetBinaryFileProtection(bool protect)
        {
            protectBinaryFile = protect;
            Debug.Log($"[ModelController] Binary file protection {(protect ? "ENABLED" : "DISABLED")}");
        }

        /// <summary>
        /// Creates the worker and input tensor for a language, reads the base vocabulary from the
        /// model's "wc_vocab_json" output and then appends any persisted custom words.
        /// </summary>
        /// <param name="language">Language code used as key for all per-language state.</param>
        /// <param name="modelAsset">Model asset to load.</param>
        /// <exception cref="InvalidOperationException">Reading or parsing the vocabulary failed.</exception>
        public void LoadModelBin(string language, ModelAsset modelAsset)
        {
            // Always load the model structure first.
            var model = ModelLoader.Load(modelAsset);

            // (Re)create worker + input tensor, releasing the previous ones if any.
            if (m_Workers.TryGetValue(language, out var previousWorker))
            {
                previousWorker?.Dispose();
            }

            if (m_InputsByLanguage.TryGetValue(language, out var previousInput))
            {
                previousInput?.Dispose();
            }

            var worker = new Worker(model, BackendType.CPU);
            m_Workers[language] = worker;
            m_InputsByLanguage[language] = new Tensor<int>(new TensorShape(1));

            // Always load base vocabulary from ONNX JSON (no caching).
            var dummyInput = new Tensor<int>(new TensorShape(1));
            dummyInput[0] = 0;

            try
            {
                worker.Schedule(dummyInput);
                var vocabJsonTensor = worker.PeekOutput("wc_vocab_json");
                if (vocabJsonTensor != null)
                {
                    try
                    {
                        string jsonString = ReadVocabularyJson(vocabJsonTensor);
                        if (jsonString != null)
                        {
                            ParseAndLoadVocabulary(language, jsonString);
                        }
                    }
                    finally
                    {
                        // NOTE(review): the tensor comes from PeekOutput; confirm against the
                        // Inference Engine docs that the caller is expected to dispose it.
                        vocabJsonTensor.Dispose();
                    }
                }
            }
            catch (Exception e)
            {
                throw new InvalidOperationException($"Error loading base vocabulary for {language}: {e.Message}", e);
            }
            finally
            {
                dummyInput.Dispose();
            }

            // Now load any custom words from binary file.
            LoadCustomWordsFromBinary(language);
        }

        /// <summary>
        /// Decodes the UTF-8 vocabulary JSON stored in a model output tensor.
        /// </summary>
        /// <returns>
        /// The JSON text; null when the tensor has an unsupported element type or, for the
        /// int encoding, contains no bytes before the first zero.
        /// </returns>
        private static string ReadVocabularyJson(Tensor vocabJsonTensor)
        {
            // NOTE(review): element type of this branch reconstructed as byte from the
            // byte[] assignment below - confirm against the exported model.
            if (vocabJsonTensor is Tensor<byte> byteTensor)
            {
                var jsonData = new byte[byteTensor.shape.length];
                var cpuBytes = byteTensor.ReadbackAndClone();
                try
                {
                    for (int i = 0; i < jsonData.Length; i++)
                    {
                        jsonData[i] = cpuBytes[i];
                    }

                    return Encoding.UTF8.GetString(jsonData);
                }
                finally
                {
                    cpuBytes.Dispose();
                }
            }

            if (vocabJsonTensor is Tensor<int> intTensor)
            {
                var cpuInts = intTensor.ReadbackAndClone();
                try
                {
                    var jsonBytes = new List<byte>();
                    int elementCount = cpuInts.shape.length;
                    for (int i = 0; i < elementCount; i++)
                    {
                        int value = cpuInts[i];
                        if (value == 0)
                        {
                            break; // zero terminates the text
                        }

                        if (value >= 0 && value <= 255)
                        {
                            jsonBytes.Add((byte)value);
                        }
                    }

                    return jsonBytes.Count > 0 ? Encoding.UTF8.GetString(jsonBytes.ToArray()) : null;
                }
                finally
                {
                    cpuInts.Dispose();
                }
            }

            return null;
        }

        /// <summary>
        /// Loads bytes from StreamingAssets using UnityWebRequest for Android compatibility.
        /// </summary>
        /// <returns>The file content, or null when the file is missing or cannot be read.</returns>
        private byte[] LoadStreamingAssetBytes(string path)
        {
            try
            {
#if UNITY_ANDROID && !UNITY_EDITOR
                using var request = UnityEngine.Networking.UnityWebRequest.Get(path);
                var operation = request.SendWebRequest();

                // Blocks the calling thread until the request has finished.
                while (!operation.isDone) { }

                if (request.result == UnityEngine.Networking.UnityWebRequest.Result.Success)
                {
                    return request.downloadHandler.data;
                }

                return null;
#else
                if (File.Exists(path))
                {
                    return File.ReadAllBytes(path);
                }

                return null;
#endif
            }
            catch (Exception e)
            {
                Debug.LogError($"[ModelController] Exception in LoadStreamingAssetBytes: {e.Message}");
                return null;
            }
        }

        /// <summary>
        /// Loads custom words from binary file and adds them to the existing vocabulary.
        /// Binary file contains ONLY custom words, not the entire model cache.
        /// </summary>
        private void LoadCustomWordsFromBinary(string language)
        {
            if (!wordToIndexByLanguage.TryGetValue(language, out var vocab))
            {
                return;
            }

            string path = Path.Combine(Application.streamingAssetsPath, "WordConnectGameToolkit", "model", "custom", $"{language}_custom_words.bin");
            byte[] fileData = LoadStreamingAssetBytes(path);
            if (fileData == null)
            {
                return;
            }

            try
            {
                using var ms = new MemoryStream(fileData);
                using var br = new BinaryReader(ms, Encoding.UTF8);

                // Read header.
                if (br.ReadInt32() != CustomWordsMagic)
                {
                    return;
                }

                br.ReadInt32();                        // base vocab size at save time (not needed here)
                int customWordCount = br.ReadInt32();  // number of custom words
                int vectorDim = br.ReadInt32();        // vector dimension

                // Verify compatibility.
                if (!m_VectorDimensionByLanguage.TryGetValue(language, out int expectedDim) || vectorDim != expectedDim)
                {
                    return;
                }

                // Read everything first so a truncated file cannot leave a half-applied vocabulary.
                var customWords = new List<string>();
                for (int i = 0; i < customWordCount; i++)
                {
                    customWords.Add(br.ReadString());
                }

                // Start at the first free index; this matches where AddWordAndSave continues.
                m_VocabSizeByLanguage.TryGetValue(language, out int vocabSize);
                int nextIndex = Math.Max(vocab.Count, vocabSize);

                foreach (string word in customWords)
                {
                    // Never remap a word that already has an index (e.g. a base word that an
                    // older file listed as custom).
                    if (vocab.ContainsKey(word))
                    {
                        continue;
                    }

                    vocab[word] = nextIndex++;
                }

                // Update vocabulary size.
                m_VocabSizeByLanguage[language] = nextIndex;
            }
            catch (Exception e)
            {
                Debug.LogError($"[ModelController] Error loading custom words for '{language}': {e.Message}");
            }
        }

        /// <summary>
        /// Saves only the custom words (not base vocabulary) to binary file.
        /// This creates a lightweight file with just the added words.
        /// </summary>
        private void SaveCustomWordsToBinary(string language)
        {
            if (!wordToIndexByLanguage.TryGetValue(language, out var vocab))
            {
                Debug.LogError($"[ModelController] No vocabulary loaded for '{language}' - cannot save custom words");
                return;
            }

            string streamingAssetsDir = Path.Combine(Application.dataPath, "StreamingAssets");
            string modelDir = Path.Combine(streamingAssetsDir, "WordConnectGameToolkit", "model", "custom");
            string path = Path.Combine(modelDir, $"{language}_custom_words.bin");

            try
            {
                // Custom words are exactly those whose index lies past the base vocabulary.
                int baseVocabSize = GetEstimatedBaseVocabSize(language);
                var customWords = vocab
                    .Where(kvp => kvp.Value >= baseVocabSize)
                    .OrderBy(kvp => kvp.Value)
                    .ToList();

                if (customWords.Count == 0)
                {
                    Debug.Log($"[ModelController] No custom words to save for '{language}'");
                    return;
                }

                // Create StreamingAssets folder structure.
                Directory.CreateDirectory(modelDir);

                using var fs = new FileStream(path, FileMode.Create, FileAccess.Write);
                using var bw = new BinaryWriter(fs, Encoding.UTF8);

                // Write header.
                bw.Write(CustomWordsMagic);                        // "CUST" magic number
                bw.Write(baseVocabSize);                           // base vocab size when custom words were added
                bw.Write(customWords.Count);                       // number of custom words
                bw.Write(m_VectorDimensionByLanguage[language]);   // vector dimension

                // Write custom words (just the words, not embeddings).
                foreach (var kvp in customWords)
                {
                    bw.Write(kvp.Key);
                }
            }
            catch (Exception e)
            {
                Debug.LogError($"[ModelController] Error saving custom words for '{language}': {e.Message}");
            }
        }

        /// <summary>
        /// Returns the first index that does not belong to the model's own vocabulary.
        /// Uses the value recorded when the vocabulary was parsed; the size-based guess is only
        /// a last resort for a language whose base size was never recorded.
        /// </summary>
        private int GetEstimatedBaseVocabSize(string language)
        {
            if (m_BaseVocabSizeByLanguage.TryGetValue(language, out int recordedSize))
            {
                return recordedSize;
            }

            // Fallback heuristic - assume custom words start after a reasonable base size.
            int currentSize = wordToIndexByLanguage.TryGetValue(language, out var vocab) ? vocab.Count : 0;
            return currentSize > 1000 ? currentSize - 100 : Math.Max(1, currentSize / 2);
        }

        /// <summary>
        /// Parses the vocabulary JSON ("word_to_index", "vector_size", "vocab_size") and replaces
        /// the per-language vocabulary state with its content. Keys are stored normalized.
        /// </summary>
        /// <exception cref="InvalidOperationException">The JSON is malformed or misses a field.</exception>
        private void ParseAndLoadVocabulary(string language, string jsonString)
        {
            try
            {
                var jsonObject = JObject.Parse(jsonString);
                var wordIndexDict = jsonObject["word_to_index"].ToObject<Dictionary<string, int>>();

                var wordToIndex = new Dictionary<string, int>(wordIndexDict.Count);
                int highestIndex = -1;
                foreach (var pair in wordIndexDict)
                {
                    // Words that normalize to the same key share one entry (last one wins).
                    wordToIndex[NormalizeText(pair.Key)] = pair.Value;
                    highestIndex = Math.Max(highestIndex, pair.Value);
                }

                int vectorDimension = jsonObject["vector_size"].Value<int>();
                int vocabSize = jsonObject["vocab_size"].Value<int>();

                // First free index: never below the declared size nor below any used index.
                int firstFreeIndex = Math.Max(vocabSize, highestIndex + 1);

                wordToIndexByLanguage[language] = wordToIndex;
                m_VectorDimensionByLanguage[language] = vectorDimension;
                m_VocabSizeByLanguage[language] = firstFreeIndex;
                m_BaseVocabSizeByLanguage[language] = firstFreeIndex;
            }
            catch (Exception e)
            {
                throw new InvalidOperationException($"Error parsing vocabulary JSON for {language}: {e.Message}", e);
            }
        }

        /// <summary>
        /// Returns the embedding vector of a word.
        /// </summary>
        /// <param name="word">Word to look up; it is normalized before the lookup.</param>
        /// <param name="language">Language code, defaults to the current language.</param>
        /// <returns>
        /// The vector; a zero vector for custom words (they have no row in the model);
        /// null when the model is not loaded, the word is null/unknown, or the model output
        /// is not a float tensor.
        /// </returns>
        public float[] GetWordVector(string word, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return null;
            }

            word = NormalizeText(word);
            if (word == null || !wordToIndexByLanguage[language].TryGetValue(word, out int wordIndex))
            {
                return null;
            }

            int dimension = m_VectorDimensionByLanguage[language];

            // Custom words were appended after the model's vocabulary; the worker's embedding
            // matrix has no row for them, so answer with the zero vector they were added with.
            if (m_BaseVocabSizeByLanguage.TryGetValue(language, out int baseVocabSize) && wordIndex >= baseVocabSize)
            {
                return new float[dimension];
            }

            var input = m_InputsByLanguage[language];
            var worker = m_Workers[language];

            input[0] = wordIndex;
            worker.Schedule(input);

            var outputTensor = worker.PeekOutput() as Tensor<float>;
            if (outputTensor == null)
            {
                return null;
            }

            var cpuTensor = outputTensor.ReadbackAndClone();
            try
            {
                var result = new float[dimension];
                for (int i = 0; i < dimension; i++)
                {
                    result[i] = cpuTensor[i];
                }

                return result;
            }
            finally
            {
                cpuTensor.Dispose();
                // NOTE(review): the tensor comes from PeekOutput; confirm against the
                // Inference Engine docs that the caller is expected to dispose it.
                outputTensor.Dispose();
            }
        }

        /// <summary>
        /// Returns true when the word is not banned and a vector can be obtained for it.
        /// </summary>
        public bool IsWordKnown(string word, string language = null)
        {
            language = ResolveLanguage(language);
            if (string.IsNullOrEmpty(word))
            {
                return false;
            }

            string normalizedWord = NormalizeText(word);
            if (IsBanned(normalizedWord, language))
            {
                return false;
            }

            // GetWordVector also covers the "model not loaded" case.
            return GetWordVector(normalizedWord, language) != null;
        }

        /// <summary>Returns true when every component is approximately zero.</summary>
        private bool IsZeroVector(float[] vector)
        {
            foreach (float value in vector)
            {
                if (!Mathf.Approximately(value, 0f))
                {
                    return false;
                }
            }

            return true;
        }

        /// <summary>
        /// Cosine similarity of two words' vectors.
        /// </summary>
        /// <returns>The similarity, or -1 when the model is not loaded or a word has no vector.</returns>
        public float GetCosineSimilarity(string word1, string word2, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return -1f;
            }

            float[] vector1 = GetWordVector(word1, language);
            float[] vector2 = GetWordVector(word2, language);
            if (vector1 == null || vector2 == null)
            {
                return -1f;
            }

            return CosineSimilarity(vector1, vector2);
        }

        /// <summary>
        /// Adds a word to the in-memory vocabulary of a language and persists the custom words file.
        /// </summary>
        /// <param name="newWord">Word to add; it is normalized first.</param>
        /// <param name="language">Language code, defaults to the current language.</param>
        /// <returns>
        /// True when the word was added; false when the model is not loaded, the word is
        /// empty or already present, or the model does not have the expected embedding layout.
        /// On false the vocabulary is left untouched.
        /// </returns>
        public bool AddWordAndSave(string newWord, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                Debug.LogWarning($"[ModelController] AddWord failed – model for '{language}' not loaded.");
                return false;
            }

            newWord = NormalizeText(newWord);
            if (string.IsNullOrEmpty(newWord))
            {
                Debug.LogWarning("[ModelController] AddWord failed – word is null or empty.");
                return false;
            }

            var vocab = wordToIndexByLanguage[language];
            if (vocab.ContainsKey(newWord))
            {
                Debug.LogWarning($"[ModelController] Word '{newWord}' already exists in vocab.");
                return false;
            }

            int dim = m_VectorDimensionByLanguage[language];
            var newVector = new float[dim]; // new words start with a zero embedding

            // 1. Validate everything BEFORE touching the vocabulary, so a failure leaves no trace.
            if (!languageModels.TryGetValue(language, out var modelAsset) || modelAsset == null)
            {
                Debug.LogError($"[ModelController] Missing ModelAsset for language '{language}'.");
                return false;
            }

            // Load the Model anew so we can mutate its constants safely.
            var model = ModelLoader.Load(modelAsset);

            // We assume the first constant is the [vocab, dim] embedding matrix.
            // If your graph changes, adjust the index or name match here.
            if (model.constants == null || model.constants.Count == 0)
            {
                Debug.LogError("[ModelController] Model has no constants — cannot extend.");
                return false;
            }

            var embConst = model.constants[0];
            if (embConst.dataType != DataType.Float || embConst.shape.rank != 2 || embConst.shape[1] != dim)
            {
                Debug.LogError("[ModelController] Unexpected embedding tensor layout.");
                return false;
            }

            // 2. Extend the embedding constant of the freshly loaded model by one row.
            // NOTE(review): this model instance is not handed to any worker in this method,
            // so the running worker keeps its original matrix - confirm whether this step is
            // still required.
            int oldVocab = embConst.shape[0];
            int oldElems = embConst.shape.length; // vocab * dim
            float[] oldBuf = new float[oldElems];
            NativeTensorArray.Copy(embConst.weights, 0, oldBuf, 0, oldElems);

            // Build new buffer = old + newVector.
            var newBuf = new float[oldElems + dim];
            Buffer.BlockCopy(oldBuf, 0, newBuf, 0, oldElems * sizeof(float));
            Buffer.BlockCopy(newVector, 0, newBuf, oldElems * sizeof(float), dim * sizeof(float));

            // ctor args: (Array data, int srcElementOffset, int srcElementSize, int numDestElement)
            var newWeights = new NativeTensorArrayFromManagedArray(
                newBuf,          // managed float[]
                0,               // start at element 0
                sizeof(float),
                newBuf.Length);  // total elements

#pragma warning disable 618 // Constant.weights setter is obsolete but still functional
            embConst.weights = newWeights;
#pragma warning restore 618
            embConst.shape = new TensorShape(oldVocab + 1, dim); // update shape metadata

            // 3. Update dictionaries.
            int newIndex = m_VocabSizeByLanguage[language];
            vocab[newWord] = newIndex;
            m_VocabSizeByLanguage[language] = newIndex + 1;

            // 4. Write new BIN.
            SaveCustomWordsToBinary(language);

            Debug.Log($"[ModelController] Successfully added word '{newWord}' to {language} vocabulary at index {newIndex}");
            return true;
        }

        /// <summary>
        /// Cosine similarity of two vectors of equal length.
        /// </summary>
        /// <returns>The similarity, or 0 when either vector has zero magnitude.</returns>
        private float CosineSimilarity(float[] v1, float[] v2)
        {
            float dotProduct = 0;
            float norm1 = 0;
            float norm2 = 0;

            for (int i = 0; i < v1.Length; i++)
            {
                dotProduct += v1[i] * v2[i];
                norm1 += v1[i] * v1[i];
                norm2 += v2[i] * v2[i];
            }

            // A zero vector has no direction; dividing would produce NaN and corrupt rankings.
            float denominator = Mathf.Sqrt(norm1) * Mathf.Sqrt(norm2);
            return denominator > 0f ? dotProduct / denominator : 0f;
        }

        /// <summary>
        /// Disposes all workers and input tensors and clears every per-language table.
        /// </summary>
        public void Dispose()
        {
            foreach (var worker in m_Workers.Values)
            {
                worker?.Dispose();
            }

            foreach (var input in m_InputsByLanguage.Values)
            {
                input?.Dispose();
            }

            m_Workers.Clear();
            m_InputsByLanguage.Clear();
            wordToIndexByLanguage.Clear();
            m_VectorDimensionByLanguage.Clear();
            m_VocabSizeByLanguage.Clear();
            m_BaseVocabSizeByLanguage.Clear();
        }

        /// <summary>
        /// Returns up to <paramref name="wordCount"/> distinct, non-banned words chosen at random.
        /// </summary>
        public List<string> GetRandomWords(int wordCount, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return new List<string>();
            }

            var words = wordToIndexByLanguage[language].Keys
                .Where(word => !IsBanned(word, language))
                .ToList();

            int take = Mathf.Min(wordCount, words.Count);
            if (take <= 0)
            {
                return new List<string>();
            }

            // Partial Fisher-Yates shuffle: after step i the first i+1 slots hold a uniform sample.
            for (int i = 0; i < take; i++)
            {
                int j = m_Random.Next(i, words.Count);
                (words[i], words[j]) = (words[j], words[i]);
            }

            return words.GetRange(0, take);
        }

        /// <summary>Returns every non-banned word of the language's vocabulary.</summary>
        public List<string> GetAllWords(string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return new List<string>();
            }

            return wordToIndexByLanguage[language].Keys
                .Where(word => !IsBanned(word, language))
                .ToList();
        }

        /// <summary>
        /// Returns the non-banned vocabulary words that can be spelled with the given symbols
        /// (each symbol used at most as often as it occurs), ranked by similarity to a composite
        /// vector of the symbols when one is available, otherwise by descending length.
        /// </summary>
        public List<string> GetWordsFromSymbols(string inputSymbols, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return new List<string>();
            }

            if (string.IsNullOrEmpty(inputSymbols))
            {
                return new List<string>();
            }

            inputSymbols = NormalizeText(inputSymbols);

            var charCounts = new Dictionary<char, int>();
            foreach (char c in inputSymbols)
            {
                charCounts.TryGetValue(c, out int seen);
                charCounts[c] = seen + 1;
            }

            var candidateWords = new List<string>();
            foreach (var word in wordToIndexByLanguage[language].Keys)
            {
                // A word longer than the symbol pool can never be spelled from it.
                if (word.Length > inputSymbols.Length || IsBanned(word, language))
                {
                    continue;
                }

                var remainingCounts = new Dictionary<char, int>(charCounts);
                bool isValid = true;
                foreach (char c in word)
                {
                    if (!remainingCounts.TryGetValue(c, out int left) || left <= 0)
                    {
                        isValid = false;
                        break;
                    }

                    remainingCounts[c] = left - 1;
                }

                if (isValid)
                {
                    candidateWords.Add(word);
                }
            }

            if (!candidateWords.Contains(inputSymbols)
                && !IsBanned(inputSymbols, language)
                && IsWordKnown(inputSymbols, language))
            {
                candidateWords.Add(inputSymbols);
            }

            float[] referenceVector = CreateInputTensor(inputSymbols, language);
            if (referenceVector != null && candidateWords.Count > 0)
            {
                return RankWordsBySimilarity(candidateWords, referenceVector, 200, language);
            }

            return candidateWords
                .OrderByDescending(w => w.Length)
                .ToList();
        }

        /// <summary>
        /// Builds a unit-length composite vector for a set of symbols by summing the vectors of
        /// the five vocabulary words with the highest share of characters found in the symbols.
        /// </summary>
        /// <returns>The composite vector, or null when the model is not loaded or the input is empty.</returns>
        public float[] CreateInputTensor(string inputSymbols, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return null;
            }

            if (string.IsNullOrEmpty(inputSymbols))
            {
                return null;
            }

            inputSymbols = NormalizeText(inputSymbols);
            var symbolSet = new HashSet<char>(inputSymbols);

            var bestMatches = wordToIndexByLanguage[language].Keys
                .Where(word => word.Length > 0) // avoid a 0/0 ratio for an empty key
                .Select(word => new
                {
                    Word = word,
                    SharedChars = word.Count(c => symbolSet.Contains(c))
                })
                .OrderByDescending(x => (float)x.SharedChars / x.Word.Length)
                .Take(5)
                .ToList();

            if (bestMatches.Count == 0)
            {
                return null;
            }

            int dimension = m_VectorDimensionByLanguage[language];
            float[] compositeVector = new float[dimension];

            foreach (var match in bestMatches)
            {
                float[] wordVector = GetWordVector(match.Word, language);
                if (wordVector == null)
                {
                    continue;
                }

                for (int i = 0; i < dimension; i++)
                {
                    compositeVector[i] += wordVector[i];
                }
            }

            float magnitude = 0;
            for (int i = 0; i < dimension; i++)
            {
                magnitude += compositeVector[i] * compositeVector[i];
            }

            magnitude = Mathf.Sqrt(magnitude);
            if (magnitude > 0)
            {
                for (int i = 0; i < dimension; i++)
                {
                    compositeVector[i] /= magnitude;
                }
            }

            return compositeVector;
        }

        /// <summary>
        /// Orders candidate words by cosine similarity to a reference vector, slightly favouring
        /// longer words (length bonus saturates at 10 characters), and keeps the best ones.
        /// </summary>
        private List<string> RankWordsBySimilarity(List<string> candidateWords, float[] referenceVector, int maxResults, string language = null)
        {
            language = ResolveLanguage(language);
            var scoredWords = new List<(string word, float score)>();

            foreach (var word in candidateWords)
            {
                float[] wordVector = GetWordVector(word, language);
                if (wordVector == null)
                {
                    continue;
                }

                float similarity = CosineSimilarity(referenceVector, wordVector);
                float adjustedScore = similarity * (0.8f + 0.2f * Mathf.Min(1f, word.Length / 10f));
                scoredWords.Add((word, adjustedScore));
            }

            return scoredWords
                .OrderByDescending(pair => pair.score)
                .Take(maxResults)
                .Select(pair => pair.word)
                .ToList();
        }

        /// <summary>
        /// Gets a list of words that are exactly the specified length.
        /// Returns complete words only, does not truncate or modify the words.
        /// </summary>
        /// <param name="length">The exact length of words to return</param>
        /// <param name="language">Language code, defaults to current language if not specified</param>
        /// <returns>List of complete words of the specified length, in random order</returns>
        public List<string> GetWordsWithLength(int length, string language = null)
        {
            language = ResolveLanguage(language);
            if (!IsModelLoaded(language))
            {
                return new List<string>();
            }

            if (length <= 0)
            {
                Debug.LogWarning($"Invalid word length requested: {length}. Length must be greater than 0.");
                return new List<string>();
            }

            var matchingWords = wordToIndexByLanguage[language].Keys
                .Where(word => word.Length == length && !IsBanned(word, language))
                .ToList();

            if (matchingWords.Count == 0)
            {
                Debug.LogWarning($"No words found with exact length: {length}");
                return new List<string>();
            }

            return matchingWords
                .OrderBy(word => m_Random.Next())
                .ToList();
        }

        /// <summary>
        /// Returns the non-banned vocabulary words most similar to a target vector.
        /// </summary>
        private List<string> FindSimilarWords(float[] targetVector, int maxResults, string language = null)
        {
            language = ResolveLanguage(language);
            if (!wordToIndexByLanguage.TryGetValue(language, out var vocab))
            {
                return new List<string>();
            }

            var scoredWords = new List<(string word, float score)>();
            foreach (var word in vocab.Keys)
            {
                if (IsBanned(word, language))
                {
                    continue;
                }

                float[] wordVector = GetWordVector(word, language);
                if (wordVector != null)
                {
                    scoredWords.Add((word, CosineSimilarity(targetVector, wordVector)));
                }
            }

            return scoredWords
                .OrderByDescending(pair => pair.score)
                .Take(maxResults)
                .Select(pair => pair.word)
                .ToList();
        }

        public IEnumerable<string> GetAvailableLanguages()
        {
            return AvailableLanguages;
        }

        /// <summary>
        /// Clears custom words cache for a specific language or all languages.
        /// This only removes custom word files, not the base model data.
        /// </summary>
        /// <param name="language">Language to clear, or null to clear all</param>
        public void ClearCustomWordsCache(string language = null)
        {
            string customDir = Path.Combine(Application.dataPath, "StreamingAssets", "WordConnectGameToolkit", "model", "custom");
            if (!Directory.Exists(customDir))
            {
                return;
            }

            try
            {
                if (language != null)
                {
                    string customPath = Path.Combine(customDir, $"{language}_custom_words.bin");
                    if (File.Exists(customPath))
                    {
                        File.Delete(customPath);
                        Debug.Log($"[ModelController] Cleared custom words cache for language: {language}");
                    }
                }
                else
                {
                    var customFiles = Directory.GetFiles(customDir, "*_custom_words.bin");
                    foreach (var file in customFiles)
                    {
                        File.Delete(file);
                    }

                    Debug.Log($"[ModelController] Cleared all custom words cache files ({customFiles.Length} files)");
                }
            }
            catch (Exception e)
            {
                Debug.LogError($"[ModelController] Error clearing custom words cache: {e.Message}");
            }
        }
    }

    /// <summary>
    /// Word-embedding model access: vocabulary queries, vectors, similarity and custom words.
    /// </summary>
    public interface IModelController : IDisposable
    {
        bool IsModelLoaded(string language = null);
        void LoadModels();
        float[] GetWordVector(string word, string language = null);
        bool IsWordKnown(string word, string language = null);
        float GetCosineSimilarity(string word1, string word2, string language = null);
        List<string> GetRandomWords(int wordCount, string language = null);
        List<string> GetAllWords(string language = null);
        List<string> GetWordsFromSymbols(string inputSymbols, string language = null);
        List<string> GetWordsWithLength(int length, string language = null);
        IEnumerable<string> GetAvailableLanguages();
        float[] CreateInputTensor(string input, string language);
        bool AddWordAndSave(string newWord, string language = null);
        void ClearCustomWordsCache(string language = null);
        void SetBinaryFileProtection(bool protect);
    }
}