File size: 6,755 Bytes
8ef2d83 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
"""Tests for ARMS-HAT Python bindings."""
import pytest
import tempfile
import os
def test_import():
    """Check that the public names can be imported from ``arms_hat``."""
    # Imported inside the test so an import failure is reported as this
    # test failing rather than as a collection error for the whole module.
    from arms_hat import (  # noqa: F401
        HatIndex,
        HatConfig,
        SearchResult,
    )
def test_create_index():
    """Check that a newly created cosine index reports itself as empty."""
    from arms_hat import HatIndex

    fresh = HatIndex.cosine(128)

    # Both views of emptiness must agree: zero length and is_empty().
    assert len(fresh) == 0
    assert fresh.is_empty()
def test_add_and_query():
    """Add ten points, then check that a query matching the first finds it."""
    from arms_hat import HatIndex

    width = 64
    hat = HatIndex.cosine(width)

    def two_hot(start):
        """Return a vector with 1.0 at ``start`` and 0.5 at ``start + 1``."""
        vec = [0.0] * width
        vec[start % width] = 1.0
        vec[(start + 1) % width] = 0.5
        return vec

    # Insert the points, remembering the ID handed back for each one.
    assigned = []
    for slot in range(10):
        point_id = hat.add(two_hot(slot))
        assigned.append(point_id)
        assert len(point_id) == 32  # Hex ID

    assert len(hat) == 10
    assert not hat.is_empty()

    # The query is identical to the first inserted point.
    hits = hat.near(two_hot(0), k=5)
    assert len(hits) == 5

    # The exact match must rank first with a high cosine similarity.
    best = hits[0]
    assert best.id == assigned[0]
    assert best.score > 0.9
def test_sessions():
    """Add points on both sides of ``new_session()`` and check the stats."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def one_hot(pos):
        """Return a ``width``-long vector that is 1.0 only at ``pos``."""
        return [1.0 if col == pos % width else 0.0 for col in range(width)]

    # First session: one-hot vectors at positions 0..4.
    for pos in range(5):
        hat.add(one_hot(pos))

    hat.new_session()

    # Second session: one-hot vectors at positions 10..14.
    for pos in range(10, 15):
        hat.add(one_hot(pos))

    summary = hat.stats()
    assert summary.session_count >= 1  # At least one session
    assert summary.chunk_count == 10
def test_documents():
    """Add points on both sides of ``new_document()`` and check the stats."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def one_hot(pos):
        """Return a ``width``-long vector that is 1.0 only at ``pos``."""
        vec = [0.0] * width
        vec[pos] = 1.0
        return vec

    # First document: positions 0..2.
    for pos in (0, 1, 2):
        hat.add(one_hot(pos))

    hat.new_document()

    # Second document: positions 10..12.
    for pos in (10, 11, 12):
        hat.add(one_hot(pos))

    summary = hat.stats()
    assert summary.document_count >= 1
    assert summary.chunk_count == 6
def test_persistence_bytes():
    """Round-trip an index through ``to_bytes``/``from_bytes``.

    The restored index must have the same length as the original and
    return the same top hit for the same query.
    """
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Add points: a 0.1 baseline with a single 1.0 spike per point.
    # The returned IDs are not needed here, so they are not collected.
    for i in range(20):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        index.add(embedding)

    # Serialize
    data = index.to_bytes()
    assert len(data) > 0

    # Deserialize
    loaded = HatIndex.from_bytes(data)
    assert len(loaded) == len(index)

    # Query should give same results; the query equals the first point.
    query = [0.1] * dims
    query[0] = 1.0

    original_results = index.near(query, k=5)
    loaded_results = loaded.near(query, k=5)

    assert len(original_results) == len(loaded_results)
    # Only the top hit is compared: the remaining points are built
    # symmetrically, so their relative order is not pinned by this test.
    assert original_results[0].id == loaded_results[0].id
def test_persistence_file():
    """Round-trip an index through ``save``/``load`` on disk.

    The index is saved into a fresh temporary directory, so the target
    path does not exist before ``save`` runs. The existence check
    therefore proves that ``save`` created the file, and the directory
    (with the file in it) is removed automatically when the test ends.
    """
    from arms_hat import HatIndex

    dims = 64
    index = HatIndex.cosine(dims)

    # Add points: a 0.1 baseline with a single 1.0 spike per point.
    for i in range(10):
        embedding = [0.1] * dims
        embedding[i % dims] = 1.0
        index.add(embedding)

    with tempfile.TemporaryDirectory() as tmp_dir:
        path = os.path.join(tmp_dir, "index.hat")

        # Save
        index.save(path)
        assert os.path.exists(path)
        assert os.path.getsize(path) > 0

        # Load
        loaded = HatIndex.load(path)
        assert len(loaded) == len(index)
def test_config():
    """Build an index from a customised ``HatConfig`` and check it is empty."""
    from arms_hat import HatIndex, HatConfig

    # Each with_* call returns the config to use next, so they chain.
    tuned = HatConfig().with_beam_width(5).with_temporal_weight(0.1)

    hat = HatIndex.with_config(128, tuned)
    assert len(hat) == 0
def test_remove():
    """Remove one of two points and check only the other is returned."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def basis(axis):
        """Return the unit vector along ``axis``."""
        vec = [0.0] * width
        vec[axis] = 1.0
        return vec

    removed_id = hat.add(basis(0))
    kept_id = hat.add(basis(1))
    assert len(hat) == 2

    hat.remove(removed_id)
    assert len(hat) == 1

    # Asking for five neighbours must still return only the surviving point.
    hits = hat.near(basis(1), k=5)
    assert len(hits) == 1
    assert hits[0].id == kept_id
def test_consolidate():
    """Run both consolidation calls on 100 points and check none are lost."""
    from arms_hat import HatIndex

    width = 32
    total = 100
    hat = HatIndex.cosine(width)

    # One-hot vectors whose hot position cycles through the 32 dimensions.
    for n in range(total):
        vec = [0.0] * width
        vec[n % width] = 1.0
        hat.add(vec)

    # Neither call should raise.
    hat.consolidate()
    hat.consolidate_full()

    assert len(hat) == 100
def test_stats():
    """Check that ``stats()`` counts ten added points as chunks and points."""
    from arms_hat import HatIndex

    width = 64
    hat = HatIndex.cosine(width)

    # Ten one-hot vectors at positions 0..9.
    for hot in range(10):
        hat.add([1.0 if col == hot % width else 0.0 for col in range(width)])

    summary = hat.stats()
    assert summary.chunk_count == 10
    assert summary.total_points == 10
def test_repr():
    """Check that ``repr()`` of an index and a config names the class."""
    # Only the two classes whose repr is checked are imported here.
    from arms_hat import HatIndex, HatConfig

    index = HatIndex.cosine(64)
    assert "HatIndex" in repr(index)

    config = HatConfig()
    assert "HatConfig" in repr(config)
def test_near_sessions():
    """Check that ``near_sessions`` returns sessions in score order."""
    from arms_hat import HatIndex

    width = 32
    hat = HatIndex.cosine(width)

    def cluster_point(anchor, offset):
        """Return a vector with 1.0 at ``anchor`` and 0.3 at ``anchor + offset``."""
        vec = [0.0] * width
        vec[anchor] = 1.0
        vec[anchor + offset] = 0.3
        return vec

    # Session 1: points dominated by dimension 0.
    for offset in range(1, 6):
        hat.add(cluster_point(0, offset))

    hat.new_session()

    # Session 2: points dominated by dimension 10.
    for offset in range(1, 6):
        hat.add(cluster_point(10, offset))

    # The query points purely along dimension 0, like session 1.
    probe = [0.0] * width
    probe[0] = 1.0

    ranked = hat.near_sessions(probe, k=2)
    assert len(ranked) >= 1

    # When two sessions are returned, the first must not score lower.
    if len(ranked) > 1:
        assert ranked[0].score >= ranked[1].score
def test_high_dimensions():
    """Add and query 1536-dimensional vectors (OpenAI ada-002 size)."""
    from arms_hat import HatIndex

    width = 1536  # OpenAI ada-002 dimensions
    hat = HatIndex.cosine(width)

    # NOTE(review): for row == 0 every component is 0.0, so the first
    # embedding is an all-zero vector. Confirm the cosine index is meant
    # to accept zero-norm input.
    for row in range(10):
        hat.add([(col * row * 0.01) % 1.0 for col in range(width)])

    assert len(hat) == 10

    # A constant vector is used as the query; five neighbours are requested.
    hits = hat.near([0.5] * width, k=5)
    assert len(hits) == 5
if __name__ == "__main__":
    # Running the file directly hands it to pytest in verbose mode.
    pytest.main(args=[__file__, "-v"])
|