Initial commit: Unity WordConnect project

This commit is contained in:
2025-08-01 19:12:05 +08:00
commit f14db75802
3503 changed files with 448337 additions and 0 deletions

View File

@ -0,0 +1,160 @@
using System;
using System.Collections.Generic;
using System.Linq;
using UnityEngine;
using WordsToolkit.Scripts.Enums;
using WordsToolkit.Scripts.System;
namespace WordsToolkit.Scripts.NLP
{
public interface ICustomWordRepository
{
void AddWord(string word, string language = null);
void InitWords(IEnumerable<string> words, string language = null);
bool ContainsWord(string word, string language = null);
void RemoveWord(string word, string language = null);
float[] GetWordVector(string word, string language = null);
bool AddExtraWord(string word);
int GetExtraWordsCount();
HashSet<string> GetExtraWords();
void ClearExtraWords();
}
public class CustomWordRepository : ICustomWordRepository
{
private readonly string m_DefaultLanguage = "en";
private readonly Dictionary<string, HashSet<string>> customWordsByLanguage = new Dictionary<string, HashSet<string>>();
private readonly Dictionary<string, Dictionary<string, float[]>> customWordVectorsByLanguage = new Dictionary<string, Dictionary<string, float[]>>();
private HashSet<string> extraWords = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
public void AddWord(string word, string language = null)
{
language = language ?? m_DefaultLanguage;
if (string.IsNullOrEmpty(word))
return;
word = word.ToLower();
if (!customWordsByLanguage.ContainsKey(language))
{
customWordsByLanguage[language] = new HashSet<string>();
}
customWordsByLanguage[language].Add(word);
if (!customWordVectorsByLanguage.ContainsKey(language))
{
customWordVectorsByLanguage[language] = new Dictionary<string, float[]>();
}
}
public void InitWords(IEnumerable<string> words, string language = null)
{
extraWords = LoadExtraWords();
foreach (var word in words)
{
AddWord(word, language);
}
}
public bool AddExtraWord(string word)
{
if (string.IsNullOrEmpty(word))
return false;
var addExtraWord = extraWords.Add(word.ToLower());
if(addExtraWord)
{
SaveExtraWords();
PlayerPrefs.SetInt("ExtraWordsCollected", PlayerPrefs.GetInt("ExtraWordsCollected") + 1);
}
return addExtraWord;
}
private void SaveExtraWords()
{
PlayerPrefs.SetString("ExtraWords", string.Join(",", extraWords));
PlayerPrefs.Save();
}
private HashSet<string> LoadExtraWords()
{
var extraWordsString = PlayerPrefs.GetString("ExtraWords", string.Empty);
if (string.IsNullOrEmpty(extraWordsString))
return new HashSet<string>(StringComparer.OrdinalIgnoreCase);
var wordsArray = extraWordsString.Split(new[] { ',' }, StringSplitOptions.RemoveEmptyEntries);
return new HashSet<string>(wordsArray, StringComparer.OrdinalIgnoreCase);
}
public int GetExtraWordsCount()
{
return extraWords.Count;
}
public HashSet<string> GetExtraWords()
{
return extraWords;
}
public void ClearExtraWords()
{
PlayerPrefs.DeleteKey("ExtraWords");
PlayerPrefs.Save();
extraWords.Clear();
EventManager.GetEvent<string>(EGameEvent.ExtraWordClaimed).Invoke(null);
}
public bool ContainsWord(string word, string language = null)
{
language = language ?? m_DefaultLanguage;
if (string.IsNullOrEmpty(word))
return false;
word = word.ToLower();
return customWordsByLanguage.ContainsKey(language) &&
customWordsByLanguage[language].Contains(word);
}
public void RemoveWord(string word, string language = null)
{
language = language ?? m_DefaultLanguage;
if (string.IsNullOrEmpty(word))
return;
word = word.ToLower();
if (customWordsByLanguage.ContainsKey(language))
{
customWordsByLanguage[language].Remove(word);
}
if (customWordVectorsByLanguage.ContainsKey(language))
{
customWordVectorsByLanguage[language].Remove(word);
}
}
public float[] GetWordVector(string word, string language = null)
{
language = language ?? m_DefaultLanguage;
if (string.IsNullOrEmpty(word))
return null;
word = word.ToLower();
if (customWordVectorsByLanguage.ContainsKey(language) &&
customWordVectorsByLanguage[language].ContainsKey(word))
{
return customWordVectorsByLanguage[language][word];
}
return null;
}
}
}

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 747954d67fcb749af9fdac963056383c
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,8 @@
fileFormatVersion: 2
guid: eeb3753e8be4047379e59e8ed411143c
folderAsset: yes
DefaultImporter:
externalObjects: {}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,179 @@
#nullable enable
#if UNITY_EDITOR
using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Text;
using UnityEditor;
using UnityEngine;
using Random = System.Random;
namespace WordsToolkit.Scripts.NLP.Editor
{
public class MakeEmbeddingModel : EditorWindow
{
private string selectedModelPath = "";
private string selectedTxtPath = "";
private Vector2 scrollPosition;
[MenuItem("WordConnect/NLP/Build Custom Model From TXT")]
static void ShowWindow()
{
var window = GetWindow<MakeEmbeddingModel>("Embedding Model Builder");
window.minSize = new Vector2(400, 300);
window.Show();
}
void OnGUI()
{
scrollPosition = EditorGUILayout.BeginScrollView(scrollPosition);
GUILayout.Label("Embedding Model Builder", EditorStyles.boldLabel);
GUILayout.Space(10);
// Model file selection
GUILayout.Label("Select existing model file to replace (optional):", EditorStyles.label);
EditorGUILayout.BeginHorizontal();
EditorGUILayout.TextField("Model File:", selectedModelPath);
if (GUILayout.Button("Browse", GUILayout.Width(80)))
{
var path = EditorUtility.OpenFilePanel("Select model file to replace", "Assets/WordConnectGameToolkit/model", "bin");
if (!string.IsNullOrEmpty(path))
{
selectedModelPath = path;
}
}
EditorGUILayout.EndHorizontal();
GUILayout.Space(10);
// Text file selection
GUILayout.Label("Select training text file:", EditorStyles.label);
EditorGUILayout.BeginHorizontal();
EditorGUILayout.TextField("Text File:", selectedTxtPath);
if (GUILayout.Button("Browse", GUILayout.Width(80)))
{
var path = EditorUtility.OpenFilePanel("Select word list text file", "", "txt");
if (!string.IsNullOrEmpty(path))
{
selectedTxtPath = path;
}
}
EditorGUILayout.EndHorizontal();
GUILayout.Space(20);
// Build button
UnityEngine.GUI.enabled = !string.IsNullOrEmpty(selectedTxtPath);
if (GUILayout.Button("Build Model", GUILayout.Height(30)))
{
BuildModel();
}
UnityEngine.GUI.enabled = true;
GUILayout.Space(10);
// Instructions
GUILayout.Label("Instructions:", EditorStyles.boldLabel);
GUILayout.Label("1. Optionally select an existing .bin model file to replace");
GUILayout.Label("2. Select a .txt file containing word list (one word per line)");
GUILayout.Label("3. Click 'Build Model' to generate random embeddings");
GUILayout.Label("4. The model will be saved to Assets/WordConnectGameToolkit/model/");
EditorGUILayout.EndScrollView();
}
private void BuildModel()
{
if (string.IsNullOrEmpty(selectedTxtPath))
{
EditorUtility.DisplayDialog("Error", "Please select a text file.", "OK");
return;
}
try
{
// ------------------------------------------------ 1⃣ read the TXT
var words = new List<string>();
var floats = new List<float>();
using var sr = new StreamReader(selectedTxtPath);
var inv = CultureInfo.InvariantCulture;
string? line;
while ((line = sr.ReadLine()) != null)
{
if (string.IsNullOrWhiteSpace(line) || line[0] == '#')
continue;
words.Add(line.Trim());
const int dd = 128;
var rng = new Random();
for (int i = 0; i < dd; ++i)
floats.Add((float)(rng.NextDouble() * 2.0 - 1.0)); // repeat per word
}
int vocab = words.Count;
int dim = floats.Count / vocab;
// ------------------------------------------------ 2⃣ Save as .bin file
string fileName;
if (!string.IsNullOrEmpty(selectedModelPath))
{
fileName = Path.GetFileName(selectedModelPath);
}
else
{
fileName = Path.GetFileNameWithoutExtension(selectedTxtPath) + "_model.bin";
}
string dirRel = "Assets/WordConnectGameToolkit/model";
string dirFull = Path.Combine(Application.dataPath, "WordConnectGameToolkit/model");
Directory.CreateDirectory(dirFull);
string binFullPath = Path.Combine(dirFull, fileName);
SaveModelBinary(binFullPath, words, floats.ToArray(), dim, vocab);
// Tell Unity there is a new asset
AssetDatabase.ImportAsset(Path.Combine(dirRel, fileName));
EditorUtility.DisplayDialog("Success",
$"Model saved to: {Path.Combine(dirRel, fileName)}\n" +
$"Vocabulary size: {vocab}\n" +
$"Vector dimension: {dim}", "OK");
Debug.Log($"✅ Model saved to: {Path.Combine(dirRel, fileName)}");
}
catch (Exception ex)
{
EditorUtility.DisplayDialog("Error", $"Failed to build model: {ex.Message}", "OK");
Debug.LogError($"Failed to build model: {ex}");
}
}
private static void SaveModelBinary(string filePath, List<string> words, float[] embeddings, int vectorDim, int vocabSize)
{
using var bw = new BinaryWriter(File.Create(filePath));
// Magic header
bw.Write(0x564F4342); // "BOCV"
// Metadata
bw.Write(vocabSize);
bw.Write(vectorDim);
// Write vocabulary
foreach (var word in words)
{
bw.Write(word);
}
// Write embeddings
foreach (var value in embeddings)
{
bw.Write(value);
}
}
}
}
#endif

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: f5e839a21e6f4a37b12087d508dad969
timeCreated: 1750749286

View File

@ -0,0 +1,11 @@
fileFormatVersion: 2
guid: 348ebc1721c3b41b2bd58e1a04c30cb2
MonoImporter:
externalObjects: {}
serializedVersion: 2
defaultReferences: []
executionOrder: 0
icon: {instanceID: 0}
userData:
assetBundleName:
assetBundleVariant:

View File

@ -0,0 +1,908 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Newtonsoft.Json.Linq;
using Unity.Sentis;
using UnityEngine;
using System.Text;
using WordsToolkit.Scripts.Levels;
using WordsToolkit.Scripts.Services;
using WordsToolkit.Scripts.Services.BannedWords;
using VContainer;
using Random = System.Random;
namespace WordsToolkit.Scripts.NLP
{
public class ModelController : IModelController
{
private ILanguageService languageService;
private LanguageConfiguration languageConfiguration;
[SerializeField]
Dictionary<string, ModelAsset> languageModels = new Dictionary<string, ModelAsset>();
private IBannedWordsService bannedWordsService;
private Dictionary<string, Worker> m_Workers = new Dictionary<string, Worker>();
private Dictionary<string, Dictionary<string, int>> wordToIndexByLanguage = new Dictionary<string, Dictionary<string, int>>();
private Dictionary<string, int> m_VectorDimensionByLanguage = new Dictionary<string, int>();
private Dictionary<string, int> m_VocabSizeByLanguage = new Dictionary<string, int>();
private Dictionary<string, Tensor<int>> m_InputsByLanguage = new Dictionary<string, Tensor<int>>();
private string m_DefaultLanguage = "en";
public int VectorDimension = 100;
// Protection flag to prevent accidental binary overwrite when you have custom words
// NOTE: This is now mainly for the old SaveModelBinary method - new architecture uses custom words files
private bool protectBinaryFile = false;
public bool IsModelLoaded(string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
return m_Workers.ContainsKey(language) &&
wordToIndexByLanguage.ContainsKey(language) &&
wordToIndexByLanguage[language].Count > 0;
}
public ModelController(IBannedWordsService bannedWordsService,
LanguageConfiguration languageConfiguration,
ILanguageService languageService = null)
{
this.languageService = languageService;
this.languageConfiguration = languageConfiguration;
this.bannedWordsService = bannedWordsService;
InitializeFromConfiguration();
LoadModels();
}
public IEnumerable<string> AvailableLanguages => m_Workers.Keys;
private void InitializeFromConfiguration()
{
// Use LanguageService to get current language if available, fallback to configuration default
if (languageService != null)
{
m_DefaultLanguage = languageService.GetCurrentLanguageCode();
}
else
{
m_DefaultLanguage = languageConfiguration?.defaultLanguage ?? "en";
}
languageModels.Clear();
foreach (var langInfo in languageConfiguration.languages)
{
if (!string.IsNullOrEmpty(langInfo.code) && langInfo.languageModel != null)
{
languageModels[langInfo.code] = langInfo.languageModel;
}
}
}
public void LoadModels()
{
foreach (var languagePair in languageModels)
{
LoadModelBin(languagePair.Key, languagePair.Value);
}
}
/// <summary>
/// Sets whether to protect existing binary files from being overwritten.
/// When true, LoadModel() won't overwrite existing .bin files that might contain custom words.
/// </summary>
public void SetBinaryFileProtection(bool protect)
{
protectBinaryFile = protect;
Debug.Log($"[ModelController] Binary file protection {(protect ? "ENABLED" : "DISABLED")}");
}
public void LoadModelBin(string language, ModelAsset modelAsset)
{
// Always load the model structure first
var model = ModelLoader.Load(modelAsset);
// (Re)create worker + input tensor
if (m_Workers.ContainsKey(language))
{
m_Workers[language].Dispose();
m_InputsByLanguage[language]?.Dispose();
}
m_Workers[language] = new Worker(model, BackendType.CPU);
m_InputsByLanguage[language] = new Tensor<int>(new TensorShape(1));
// Always load base vocabulary from ONNX JSON (no caching)
var worker = m_Workers[language];
var dummyInput = new Tensor<int>(new TensorShape(1));
dummyInput[0] = 0;
try
{
worker.Schedule(dummyInput);
var vocabJsonTensor = worker.PeekOutput("wc_vocab_json");
if (vocabJsonTensor != null)
{
try
{
var byteTensor = vocabJsonTensor as Tensor<byte>;
if (byteTensor != null)
{
var jsonData = new byte[byteTensor.shape.length];
var cpuTensor = byteTensor.ReadbackAndClone();
try
{
for (int i = 0; i < jsonData.Length; i++)
jsonData[i] = cpuTensor[i];
string jsonString = Encoding.UTF8.GetString(jsonData);
ParseAndLoadVocabulary(language, jsonString);
}
finally
{
cpuTensor.Dispose();
}
}
else if (vocabJsonTensor is Tensor<int> intTensor)
{
var cpuTensor = intTensor.ReadbackAndClone();
try
{
var jsonBytes = new List<byte>();
for (int i = 0; i < cpuTensor.shape.length; i++)
{
int value = cpuTensor[i];
if (value == 0)
break;
if (value >= 0 && value <= 255)
jsonBytes.Add((byte)value);
}
if (jsonBytes.Count > 0)
{
string jsonString = Encoding.UTF8.GetString(jsonBytes.ToArray());
ParseAndLoadVocabulary(language, jsonString);
}
}
finally
{
cpuTensor.Dispose();
}
}
}
finally
{
vocabJsonTensor.Dispose();
}
}
}
catch (Exception e)
{
throw new Exception($"Error loading base vocabulary for {language}: {e.Message}", e);
}
finally
{
dummyInput.Dispose();
}
// Now load any custom words from binary file
LoadCustomWordsFromBinary(language);
}
/// <summary>
/// Loads custom words from binary file and adds them to the existing vocabulary.
/// Binary file contains ONLY custom words, not the entire model cache.
/// </summary>
private void LoadCustomWordsFromBinary(string language)
{
string path = Path.Combine(Application.dataPath, "WordsToolkit", "model",
"custom", $"{language}_custom_words.bin");
if (!File.Exists(path))
{
return;
}
if (!wordToIndexByLanguage.ContainsKey(language))
{
return;
}
try
{
using var fs = new FileStream(path, FileMode.Open, FileAccess.Read);
using var br = new BinaryReader(fs, Encoding.UTF8);
// Read header
if (br.ReadInt32() != 0x43555354) // "CUST" magic number
{
return;
}
int baseVocabSize = br.ReadInt32(); // Original vocab size when custom words were added
int customWordCount = br.ReadInt32(); // Number of custom words
int vectorDim = br.ReadInt32(); // Vector dimension
// Verify compatibility
if (vectorDim != m_VectorDimensionByLanguage[language])
{
return;
}
// Load custom words
int currentVocabSize = wordToIndexByLanguage[language].Count;
int nextIndex = currentVocabSize;
for (int i = 0; i < customWordCount; i++)
{
string word = br.ReadString();
// Add to vocabulary
wordToIndexByLanguage[language][word] = nextIndex++;
}
// Update vocabulary size
m_VocabSizeByLanguage[language] = nextIndex;
}
catch (Exception e)
{
Debug.LogError($"[ModelController] Error loading custom words for '{language}': {e.Message}");
}
}
/// <summary>
/// Saves only the custom words (not base vocabulary) to binary file.
/// This creates a lightweight file with just the added words.
/// </summary>
private void SaveCustomWordsToBinary(string language)
{
if (!wordToIndexByLanguage.ContainsKey(language))
{
Debug.LogError($"[ModelController] No vocabulary loaded for '{language}' - cannot save custom words");
return;
}
string dir = Path.Combine(Application.dataPath, "WordsToolkit", "model", "custom");
string path = Path.Combine(dir, $"{language}_custom_words.bin");
Directory.CreateDirectory(dir);
try
{
// Get all custom words (assuming they have higher indices than base vocabulary)
var vocab = wordToIndexByLanguage[language];
var baseVocabSize = GetEstimatedBaseVocabSize(language);
var customWords = vocab.Where(kvp => kvp.Value >= baseVocabSize)
.OrderBy(kvp => kvp.Value)
.ToList();
if (customWords.Count == 0)
{
Debug.Log($"[ModelController] No custom words to save for '{language}'");
return;
}
using var fs = new FileStream(path, FileMode.Create, FileAccess.Write);
using var bw = new BinaryWriter(fs, Encoding.UTF8);
// Write header
bw.Write(0x43555354); // "CUST" magic number
bw.Write(baseVocabSize); // Original vocab size when custom words were added
bw.Write(customWords.Count); // Number of custom words
bw.Write(m_VectorDimensionByLanguage[language]); // Vector dimension
// Write custom words (just the words, not embeddings since we can't access them easily)
foreach (var kvp in customWords)
{
bw.Write(kvp.Key); // Word string
}
}
catch (Exception e)
{
Debug.LogError($"[ModelController] Error saving custom words for '{language}': {e.Message}");
}
}
/// <summary>
/// Estimates the base vocabulary size by looking at the original model.
/// This is a heuristic - in a real implementation you might want to store this value.
/// </summary>
private int GetEstimatedBaseVocabSize(string language)
{
// Try to load the original model and get its vocabulary size
if (languageModels.TryGetValue(language, out var modelAsset) && modelAsset != null)
{
try
{
var model = ModelLoader.Load(modelAsset);
using var worker = new Worker(model, BackendType.CPU);
using var dummyInput = new Tensor<int>(new TensorShape(1));
dummyInput[0] = 0;
worker.Schedule(dummyInput);
var vocabJsonTensor = worker.PeekOutput("wc_vocab_json");
if (vocabJsonTensor != null)
{
// Parse JSON to get original vocab size
// This is simplified - you'd need to extract and parse the JSON
vocabJsonTensor.Dispose();
}
}
catch (Exception e)
{
Debug.LogWarning($"[ModelController] Could not determine base vocab size for '{language}': {e.Message}");
}
}
// Fallback heuristic - assume custom words start after a reasonable base size
var currentSize = wordToIndexByLanguage[language].Count;
return currentSize > 1000 ? currentSize - 100 : Math.Max(1, currentSize / 2);
}
private void ParseAndLoadVocabulary(string language, string jsonString)
{
try
{
var jsonObject = JObject.Parse(jsonString);
var wordToIndex = new Dictionary<string, int>();
var wordIndexDict = jsonObject["word_to_index"].ToObject<Dictionary<string, int>>();
foreach (var pair in wordIndexDict)
{
wordToIndex[pair.Key] = pair.Value;
}
int vectorDimension = jsonObject["vector_size"].Value<int>();
int vocabSize = jsonObject["vocab_size"].Value<int>();
wordToIndexByLanguage[language] = wordToIndex;
m_VectorDimensionByLanguage[language] = vectorDimension;
m_VocabSizeByLanguage[language] = vocabSize;
}
catch (Exception e)
{
throw new Exception($"Error parsing vocabulary JSON for {language}: {e.Message}", e);
}
}
public float[] GetWordVector(string word, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return null;
}
if (!wordToIndexByLanguage[language].ContainsKey(word))
{
return null;
}
m_InputsByLanguage[language][0] = wordToIndexByLanguage[language][word];
m_Workers[language].Schedule(m_InputsByLanguage[language]);
var outputTensor = m_Workers[language].PeekOutput() as Tensor<float>;
if (outputTensor == null)
{
return null;
}
var cpuTensor = outputTensor.ReadbackAndClone();
try
{
float[] result = new float[m_VectorDimensionByLanguage[language]];
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
{
result[i] = cpuTensor[i];
}
return result;
}
finally
{
cpuTensor.Dispose();
outputTensor.Dispose();
}
}
public bool IsWordKnown(string word, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (bannedWordsService.IsWordBanned(word, language))
{
return false;
}
if (!IsModelLoaded(language))
{
return false;
}
float[] vector = GetWordVector(word, language);
if (vector == null)
{
return false;
}
return true;
}
private bool IsZeroVector(float[] vector)
{
foreach (float value in vector)
{
if (!Mathf.Approximately(value, 0f))
return false;
}
return true;
}
public float GetCosineSimilarity(string word1, string word2, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return -1f;
}
float[] vector1 = GetWordVector(word1, language);
float[] vector2 = GetWordVector(word2, language);
if (vector1 == null || vector2 == null)
return -1f;
return CosineSimilarity(vector1, vector2);
}
public bool AddWordAndSave(string newWord, string language = null)
{
language ??= languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage;
if (!IsModelLoaded(language))
{
Debug.LogWarning($"[ModelController] AddWord failed  model for '{language}' not loaded.");
return false;
}
if (wordToIndexByLanguage[language].ContainsKey(newWord))
{
Debug.LogWarning($"[ModelController] Word '{newWord}' already exists in vocab.");
return false;
}
var newVector = new float[m_VectorDimensionByLanguage[language]];
int dim = m_VectorDimensionByLanguage[language];
if (newVector == null || newVector.Length != dim)
{
Debug.LogWarning($"[ModelController] Vector length mismatch (expected {dim}).");
return false;
}
// 1 Update dictionaries
int newIndex = m_VocabSizeByLanguage[language];
wordToIndexByLanguage[language][newWord] = newIndex;
m_VocabSizeByLanguage[language] = newIndex + 1;
// 2 Load the Model anew so we can mutate its constants safely
if (!languageModels.TryGetValue(language, out var modelAsset) || modelAsset == null)
{
Debug.LogError($"[ModelController] Missing ModelAsset for language '{language}'.");
return false;
}
var model = ModelLoader.Load(modelAsset);
// We assume the <b>first</b> constant is the [vocab, dim] embedding matrix.
// If your graph changes, adjust the index or name match here.
if (model.constants == null || model.constants.Count == 0)
{
Debug.LogError("[ModelController] Model has no constants — cannot extend.");
return false;
}
var embConst = model.constants[0];
if (embConst.dataType != DataType.Float || embConst.shape.rank != 2 || embConst.shape[1] != dim)
{
Debug.LogError("[ModelController] Unexpected embedding tensor layout.");
return false;
}
int oldVocab = embConst.shape[0];
int oldElems = embConst.shape.length; // vocab * dim
float[] oldBuf = new float[oldElems];
NativeTensorArray.Copy(embConst.weights, 0, oldBuf, 0, oldElems);
// Build new buffer = old + newVector
var newBuf = new float[oldElems + dim];
Buffer.BlockCopy(oldBuf, 0, newBuf, 0, oldElems * sizeof(float));
Buffer.BlockCopy(newVector,0, newBuf, oldElems * sizeof(float), dim * sizeof(float));
// Sentis requires a nongeneric NativeTensorArrayFromManagedArray
// Sentis requires (Array, bytesPerElem, length, channels)
// ctor args: (Array data, int srcElementOffset, int srcElementSize, int numDestElement)
var newWeights = new NativeTensorArrayFromManagedArray(
newBuf, // managed float[]
0, // start at element 0
sizeof(float),
newBuf.Length); // total elements
#pragma warning disable 618 // Constant.weights setter is obsolete but still functional
embConst.weights = newWeights;
#pragma warning restore 618
embConst.shape = new TensorShape(oldVocab + 1, dim); // update shape metadata
// 3 Write new BIN
SaveCustomWordsToBinary(language);
Debug.Log($"[ModelController] Successfully added word '{newWord}' to {language} vocabulary at index {newIndex}");
return true;
}
private float CosineSimilarity(float[] v1, float[] v2)
{
float dotProduct = 0;
float norm1 = 0;
float norm2 = 0;
for (int i = 0; i < v1.Length; i++)
{
dotProduct += v1[i] * v2[i];
norm1 += v1[i] * v1[i];
norm2 += v2[i] * v2[i];
}
return dotProduct / (Mathf.Sqrt(norm1) * Mathf.Sqrt(norm2));
}
private void OnDisable()
{
foreach (var worker in m_Workers.Values)
{
worker?.Dispose();
}
foreach (var input in m_InputsByLanguage.Values)
{
input?.Dispose();
}
m_Workers.Clear();
m_InputsByLanguage.Clear();
wordToIndexByLanguage.Clear();
m_VectorDimensionByLanguage.Clear();
m_VocabSizeByLanguage.Clear();
}
public void Dispose()
{
OnDisable();
}
public List<string> GetRandomWords(int wordCount, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return new List<string>();
}
var result = new List<string>();
var words = new List<string>(wordToIndexByLanguage[language].Keys)
.Where(word => !bannedWordsService.IsWordBanned(word, language))
.ToList();
var random = new Random();
wordCount = Mathf.Min(wordCount, words.Count);
while (result.Count < wordCount)
{
int randomIndex = random.Next(words.Count);
string word = words[randomIndex];
if (!result.Contains(word))
result.Add(word);
}
return result;
}
public List<string> GetAllWords(string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return new List<string>();
}
return new List<string>(wordToIndexByLanguage[language].Keys)
.Where(word => !bannedWordsService.IsWordBanned(word, language))
.ToList();
}
public List<string> GetWordsFromSymbols(string inputSymbols, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if(!IsModelLoaded(language))
{
return new List<string>();
}
if (string.IsNullOrEmpty(inputSymbols))
return new List<string>();
inputSymbols = inputSymbols.ToLower();
Dictionary<char, int> charCounts = new Dictionary<char, int>();
foreach (char c in inputSymbols)
{
if (charCounts.ContainsKey(c))
charCounts[c]++;
else
charCounts[c] = 1;
}
var candidateWords = new List<string>();
foreach (var word in wordToIndexByLanguage[language].Keys)
{
if (bannedWordsService != null && bannedWordsService.IsWordBanned(word, language))
continue;
Dictionary<char, int> remainingCounts = new Dictionary<char, int>(charCounts);
bool isValid = true;
foreach (char c in word)
{
if (!remainingCounts.ContainsKey(c) || remainingCounts[c] <= 0)
{
isValid = false;
break;
}
remainingCounts[c]--;
}
if (isValid)
{
candidateWords.Add(word);
}
}
if(!candidateWords.Contains(inputSymbols) &&
!bannedWordsService.IsWordBanned(inputSymbols, language) && IsWordKnown(inputSymbols, language))
{
candidateWords.Add(inputSymbols);
}
float[] referenceVector = CreateInputTensor(inputSymbols, language);
if (referenceVector != null && candidateWords.Count > 0)
{
return RankWordsBySimilarity(candidateWords, referenceVector, 200, language);
}
return candidateWords
.OrderByDescending(w => w.Length)
.ToList();
}
public float[] CreateInputTensor(string inputSymbols, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return null;
}
if (string.IsNullOrEmpty(inputSymbols))
return null;
var symbolSet = new HashSet<char>(inputSymbols.ToLower());
var bestMatches = wordToIndexByLanguage[language].Keys
.Select(word => new {
Word = word,
SharedChars = word.Count(c => symbolSet.Contains(c))
})
.OrderByDescending(x => (float)x.SharedChars / x.Word.Length)
.Take(5)
.ToList();
if (bestMatches.Count == 0)
return null;
float[] compositeVector = new float[m_VectorDimensionByLanguage[language]];
foreach (var match in bestMatches)
{
float[] wordVector = GetWordVector(match.Word, language);
if (wordVector != null)
{
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
{
compositeVector[i] += wordVector[i];
}
}
}
float magnitude = 0;
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
{
magnitude += compositeVector[i] * compositeVector[i];
}
magnitude = Mathf.Sqrt(magnitude);
if (magnitude > 0)
{
for (int i = 0; i < m_VectorDimensionByLanguage[language]; i++)
{
compositeVector[i] /= magnitude;
}
}
return compositeVector;
}
private List<string> RankWordsBySimilarity(List<string> candidateWords, float[] referenceVector, int maxResults, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
var scoredWords = new List<(string word, float score)>();
foreach (var word in candidateWords)
{
float[] wordVector = GetWordVector(word, language);
if (wordVector != null)
{
float similarity = CosineSimilarity(referenceVector, wordVector);
float adjustedScore = similarity * (0.8f + 0.2f * Mathf.Min(1f, word.Length / 10f));
scoredWords.Add((word, adjustedScore));
}
}
return scoredWords
.OrderByDescending(pair => pair.score)
.Take(maxResults)
.Select(pair => pair.word)
.ToList();
}
/// <summary>
/// Gets a list of words that are exactly the specified length.
/// Returns complete words only, does not truncate or modify the words.
/// </summary>
/// <param name="length">The exact length of words to return</param>
/// <param name="language">Language code, defaults to current language if not specified</param>
/// <returns>List of complete words of the specified length</returns>
public List<string> GetWordsWithLength(int length, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
if (!IsModelLoaded(language))
{
return new List<string>();
}
if (length <= 0)
{
Debug.LogWarning($"Invalid word length requested: {length}. Length must be greater than 0.");
return new List<string>();
}
var random = new Random();
var matchingWords = wordToIndexByLanguage[language].Keys
.Where(word => word != null && word.Length == length)
.Where(word => bannedWordsService == null || !bannedWordsService.IsWordBanned(word, language))
.ToList();
if (matchingWords.Count == 0)
{
Debug.LogWarning($"No words found with exact length: {length}");
return new List<string>();
}
var sampleWords = matchingWords
.OrderBy(word => random.Next())
.ToList();
if (sampleWords.Count == 0)
{
return new List<string>();
}
return sampleWords;
}
private List<string> FindSimilarWords(float[] targetVector, int maxResults, string language = null)
{
language = language ?? (languageService?.GetCurrentLanguageCode() ?? m_DefaultLanguage);
var scoredWords = new List<(string word, float score)>();
foreach (var word in wordToIndexByLanguage[language].Keys)
{
if (bannedWordsService.IsWordBanned(word, language))
continue;
float[] wordVector = GetWordVector(word, language);
if (wordVector != null)
{
float similarity = CosineSimilarity(targetVector, wordVector);
scoredWords.Add((word, similarity));
}
}
return scoredWords
.OrderByDescending(pair => pair.score)
.Take(maxResults)
.Select(pair => pair.word)
.ToList();
}
public IEnumerable<string> GetAvailableLanguages()
{
return AvailableLanguages;
}
/// <summary>
/// Clears custom words cache for a specific language or all languages.
/// This only removes custom word files, not the base model data.
/// </summary>
/// <param name="language">Language to clear, or null to clear all</param>
public void ClearCustomWordsCache(string language = null)
{
string customDir = Path.Combine(Application.dataPath, "WordsToolkit", "model", "custom");
if (!Directory.Exists(customDir))
return;
try
{
if (language != null)
{
string customPath = Path.Combine(customDir, $"{language}_custom_words.bin");
if (File.Exists(customPath))
{
File.Delete(customPath);
Debug.Log($"[ModelController] Cleared custom words cache for language: {language}");
}
}
else
{
var customFiles = Directory.GetFiles(customDir, "*_custom_words.bin");
foreach (var file in customFiles)
{
File.Delete(file);
}
Debug.Log($"[ModelController] Cleared all custom words cache files ({customFiles.Length} files)");
}
}
catch (Exception e)
{
Debug.LogError($"[ModelController] Error clearing custom words cache: {e.Message}");
}
}
}
public interface IModelController : IDisposable
{
bool IsModelLoaded(string language = null);
void LoadModels();
float[] GetWordVector(string word, string language = null);
bool IsWordKnown(string word, string language = null);
float GetCosineSimilarity(string word1, string word2, string language = null);
List<string> GetRandomWords(int wordCount, string language = null);
List<string> GetAllWords(string language = null);
List<string> GetWordsFromSymbols(string inputSymbols, string language = null);
List<string> GetWordsWithLength(int length, string language = null);
IEnumerable<string> GetAvailableLanguages();
float[] CreateInputTensor(string input, string language);
bool AddWordAndSave(string newWord, string language = null);
void ClearCustomWordsCache(string language = null);
void SetBinaryFileProtection(bool protect);
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 29670608df9b4ec78c2dda672314e9f5
timeCreated: 1740286753

View File

@ -0,0 +1,58 @@
using UnityEngine;
using UnityEngine.Serialization;
namespace WordsToolkit.Scripts.NLP
{
public class WordEmbeddingTest : MonoBehaviour
{
[FormerlySerializedAs("wordModel")]
[SerializeField]
private ModelController wordModelController;
private string[] testWords = new string[]
{
"hello",
"world",
"xyz123", // unknown word
"computer",
"asdfghjkl", // unknown word
"programming",
"@#$%",
"автор",
"программирование",
"привет",
"мир",
};
void Start()
{
if (wordModelController == null)
{
wordModelController = GetComponent<ModelController>();
}
TestWordRecognition();
}
void TestWordRecognition()
{
Debug.Log("=== Testing Word Recognition ===");
foreach (var word in testWords)
{
float[] vector = wordModelController.GetWordVector(word);
bool isKnown = vector != null && !IsZeroVector(vector);
Debug.Log($"Word: '{word}' - {(isKnown ? " Known" : " Unknown")}");
}
}
private bool IsZeroVector(float[] vector)
{
foreach (float value in vector)
{
if (!Mathf.Approximately(value, 0f))
return false;
}
return true;
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 9829e5579abc44518f6366a6f98da5b4
timeCreated: 1740290814

View File

@ -0,0 +1,49 @@
using TMPro;
using UnityEngine;
using UnityEngine.Serialization;
using UnityEngine.UI;
namespace WordsToolkit.Scripts.NLP
{
public class WordEmbeddingTestUI : MonoBehaviour
{
[FormerlySerializedAs("wordModel")]
[SerializeField] private ModelController wordModelController;
[SerializeField] private TMP_InputField inputField;
[SerializeField] private Button testButton;
private Color validColor = new Color(0.7f, 1f, 0.7f); // Light green
private Color invalidColor = new Color(1f, 0.7f, 0.7f); // Light red
private Color defaultColor = Color.white;
void Start()
{
testButton.onClick.AddListener(TestInputWord);
inputField.onEndEdit.AddListener(delegate { TestInputWord(); });
}
public void TestInputWord()
{
string word = inputField.text.Trim().ToLower();
if (string.IsNullOrEmpty(word))
{
inputField.GetComponent<Image>().color = defaultColor;
return;
}
float[] vector = wordModelController.GetWordVector(word);
bool isKnown = vector != null && !IsZeroVector(vector);
inputField.GetComponent<Image>().color = isKnown ? validColor : invalidColor;
}
private bool IsZeroVector(float[] vector)
{
foreach (float value in vector)
{
if (!Mathf.Approximately(value, 0f))
return false;
}
return true;
}
}
}

View File

@ -0,0 +1,3 @@
fileFormatVersion: 2
guid: 7d126f4abbf445f6bbb79d50d52f0019
timeCreated: 1740290961