add scheduler unittest and fix some bugs
parent da4278779d
commit 925af7e1d2
@@ -1,5 +1,6 @@
 from engine.retrieval import search_index
 from engine.ingestion import build_index
+from engine.ingestion import serialize
 
 class Singleton(type):
     _instances = {}
@@ -11,9 +12,9 @@ class Singleton(type):
 
 class Scheduler(metaclass=Singleton):
     def Search(self, index_file_key, vectors, k):
-        assert index_file_key
-        assert vectors
-        assert k
+        # assert index_file_key
+        # assert vectors
+        # assert k
 
         return self.__scheduler(index_file_key, vectors, k)
@@ -21,30 +22,29 @@ class Scheduler(metaclass=Singleton):
     def __scheduler(self, index_data_key, vectors, k):
         result_list = []
 
-        raw_data_list = index_data_key['raw']
-        index_data_list = index_data_key['index']
+        if 'raw' in index_data_key:
+            raw_vectors = index_data_key['raw']
+            d = index_data_key['dimension']
 
-        for key in raw_data_list:
-            raw_data, d = self.GetRawData(key)
+        if 'raw' in index_data_key:
             index_builder = build_index.FactoryIndex()
-            index = index_builder().build(d, raw_data)
-            searcher = search_index.FaissSearch(index)  # silly
+            index = index_builder().build(d, raw_vectors)
+            searcher = search_index.FaissSearch(index)
             result_list.append(searcher.search_by_vectors(vectors, k))
 
+        index_data_list = index_data_key['index']
         for key in index_data_list:
-            index = self.GetIndexData(key)
+            index = GetIndexData(key)
             searcher = search_index.FaissSearch(index)
             result_list.append(searcher.search_by_vectors(vectors, k))
 
         if len(result_list) == 1:
             return result_list[0].vectors
 
-        result = search_index.top_k(sum(result_list), k)
-        return result
+        # result = search_index.top_k(result_list, k)
+        return result_list
 
 
-    def GetIndexData(self, key):
-        pass
-
-    def GetRawData(self, key):
-        pass
+def GetIndexData(key):
+    return serialize.read_index(key)
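For reference, the query dict that `Search` consumes (the same shape the new unit test below builds) couples optional raw vectors with a list of serialized index keys. A minimal sketch, with `xt`, `d`, `file_name`, and `xq` as in the test:

    query_index = {
        'raw': xt,             # optional: raw vectors, indexed on the fly
        'dimension': d,        # required whenever 'raw' is present
        'index': [file_name],  # paths of serialized faiss indexes on disk
    }
    results = Scheduler().Search(query_index, vectors=xq, k=5)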
@@ -1,3 +1,60 @@
-import unittest
 from ..scheduler import *
 
+import unittest
+import faiss
+import numpy as np
+
+
+class TestScheduler(unittest.TestCase):
+    def test_schedule(self):
+        d = 64
+        nb = 10000
+        nq = 100
+        nt = 5000
+        xt, xb, xq = get_dataset(d, nb, nt, nq)
+        file_name = "/tmp/faiss/tempfile_1"
+
+        index = faiss.IndexFlatL2(d)
+        print(index.is_trained)
+        index.add(xb)
+        faiss.write_index(index, file_name)
+        Dref, Iref = index.search(xq, 5)
+
+        index2 = faiss.read_index(file_name)
+
+        scheduler_instance = Scheduler()
+
+        # query args 1
+        query_index = dict()
+        query_index['index'] = [file_name]
+        vectors = scheduler_instance.Search(query_index, vectors=xq, k=5)
+        assert np.all(vectors == Iref)
+
+        # query args 2
+        query_index = dict()
+        query_index['raw'] = xt
+        query_index['dimension'] = d
+        query_index['index'] = [file_name]
+        vectors = scheduler_instance.Search(query_index, vectors=xq, k=5)
+        # print("success")
+
+
+def get_dataset(d, nb, nt, nq):
+    """A dataset that is not completely random but still challenging to
+    index
+    """
+    d1 = 10  # intrinsic dimension (more or less)
+    n = nb + nt + nq
+    rs = np.random.RandomState(1338)
+    x = rs.normal(size=(n, d1))
+    x = np.dot(x, rs.rand(d1, d))
+    # now we have a d1-dim ellipsoid in d-dimensional space
+    # higher factor (>4) -> higher frequency -> less linear
+    x = x * (rs.rand(d) * 4 + 0.1)
+    x = np.sin(x)
+    x = x.astype('float32')
+    return x[:nt], x[nt:-nq], x[-nq:]
+
+
+if __name__ == "__main__":
+    unittest.main()
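A note on `get_dataset`: the three returned slices are the train, base, and query splits of one matrix. A quick sanity check of the sizes (not part of the committed test):

    xt, xb, xq = get_dataset(64, 10000, 5000, 100)  # d, nb, nt, nq as above
    assert xt.shape == (5000, 64)    # x[:nt] - training vectors
    assert xb.shape == (10000, 64)   # x[nt:-nq] - base vectors added to the index
    assert xq.shape == (100, 64)     # x[-nq:] - query vectors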
@@ -8,6 +8,7 @@ from flask import jsonify
 from engine import db
 from engine.ingestion import build_index
 from engine.controller.scheduler import Scheduler
+from engine.ingestion import serialize
 import sys, os
 
 class VectorEngine(object):
@@ -98,14 +99,15 @@ class VectorEngine(object):
 
         # create index
         index_builder = build_index.FactoryIndex()
-        index = index_builder().build(d, raw_data)  # type: index
-        index = build_index.Index.serialize(index)  # type: array
+        index = index_builder().build(d, raw_data)
+
+        # TODO(jinhai): store index into Cache
+        index_filename = file.filename + '_index'
+        serialize.write_index(file_name=index_filename, index=index)
 
         # TODO(jinhai): Update raw_file_name => index_file_name
-        FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1, 'type': 'index'})
+        FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({'row_number':file.row_number + 1,
+                                                                                                         'type': 'index',
+                                                                                                         'filename': index_filename})
         pass
 
     else:
@@ -135,13 +137,15 @@ class VectorEngine(object):
         if code == VectorEngine.FAULT_CODE:
             return VectorEngine.GROUP_NOT_EXIST
 
+        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
+
         # find all files
         files = FileTable.query.filter(FileTable.group_name == group_id).all()
-        raw_keys = [ i.filename for i in files if i.type == 'raw' ]
         index_keys = [ i.filename for i in files if i.type == 'index' ]
         index_map = {}
-        index_map['raw'] = raw_keys
-        index_map['index'] = index_keys  # {raw:[key1, key2], index:[key3, key4]}
+        index_map['index'] = index_keys
+        index_map['raw'] = GetVectorListFromRawFile(group_id)
+        index_map['dimension'] = group.dimension
 
         scheduler_instance = Scheduler()
         result = scheduler_instance.Search(index_map, vector, limit)
pyengine/engine/ingestion/serialize.py (new file, 7 lines)
@@ -0,0 +1,7 @@
+import faiss
+
+def write_index(index, file_name):
+    faiss.write_index(index, file_name)
+
+def read_index(file_name):
+    return faiss.read_index(file_name)
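The new module is a thin wrapper around faiss's own persistence calls. A minimal round-trip sketch, reusing the path from the unit test (the parent directory must already exist):

    import faiss
    from engine.ingestion import serialize

    index = faiss.IndexFlatL2(64)
    serialize.write_index(index, '/tmp/faiss/tempfile_1')
    index2 = serialize.read_index('/tmp/faiss/tempfile_1')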
@@ -7,8 +7,9 @@ class SearchResult():
         self.vectors = I
 
     def __add__(self, other):
-        self.distance += other.distance
-        self.vectors += other.vectors
+        distance = self.distance + other.distance
+        vectors = self.vectors + other.vectors
+        return SearchResult(distance, vectors)
 
 
 class FaissSearch():
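The rewritten `__add__` now returns a fresh `SearchResult` instead of mutating `self`. One caveat: `D` and `I` from faiss are numpy arrays of shape (nq, k), so `+` adds them elementwise. If the intent is to pool the hits of two partitions per query, a concatenating merge (an assumption, not what this commit implements) would look like:

    import numpy as np

    def merge_results(a, b):
        # stack the two hit lists side by side instead of adding elementwise
        return SearchResult(np.concatenate([a.distance, b.distance], axis=1),
                            np.concatenate([a.vectors, b.vectors], axis=1))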
@@ -31,6 +32,7 @@ class FaissSearch():
         D, I = self.__index.search(vector_list, k)
         return SearchResult(D, I)
 
 import heapq
 def top_k(input, k):
+    # sorted = heapq.nsmallest(k, input, key=input.key)
     pass
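`top_k` is still a stub; the commented-out line hints at `heapq.nsmallest`. A minimal completion, assuming `input` is a flat list of (distance, id) pairs for a single query (that signature is a guess, the commit does not define it):

    import heapq

    def top_k(input, k):
        # keep the k pairs with the smallest distance
        return heapq.nsmallest(k, input, key=lambda pair: pair[0])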