finish results reduce in mishards

This commit is contained in:
yhz 2019-11-19 20:36:08 +08:00
parent 83d9bf6966
commit 2f8be3d058

View File

@ -35,12 +35,13 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if diss[k - 1] <= source_diss[0]: if diss[k - 1] <= source_diss[0]:
return ids, diss return ids, diss
diss_t = enumerate(source_diss.extend(diss)) source_diss.extend(diss)
diss_t = enumerate(source_diss)
diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k] diss_m_rst = sorted(diss_t, key=lambda x: x[1])[:k]
diss_m_out = [id_ for _, id_ in diss_m_rst] diss_m_out = [id_ for _, id_ in diss_m_rst]
id_t = source_ids.extend(ids) source_ids.extend(ids)
id_m_out = [id_t[i] for i, _ in diss_m_rst] id_m_out = [source_ids[i] for i, _ in diss_m_rst]
return id_m_out, diss_m_out return id_m_out, diss_m_out
@ -50,8 +51,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
if not files_n_topk_results: if not files_n_topk_results:
return status, [] return status, []
# request_results = defaultdict(list)
# row_num = files_n_topk_results[0].row_num
merge_id_results = [] merge_id_results = []
merge_dis_results = [] merge_dis_results = []
@ -64,6 +63,7 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
row_num = files_collection.row_num row_num = files_collection.row_num
ids = files_collection.ids ids = files_collection.ids
diss = files_collection.distances # distance collections diss = files_collection.distances # distance collections
# TODO: batch_len is equal to topk
batch_len = len(ids) // row_num batch_len = len(ids) // row_num
for row_index in range(row_num): for row_index in range(row_num):
@ -77,28 +77,16 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
merge_id_results.append(id_batch) merge_id_results.append(id_batch)
merge_dis_results.append(dis_batch) merge_dis_results.append(dis_batch)
else: else:
merge_id_results[row_index].extend(ids[row_index * batch_len, (row_index + 1) * batch_len])
merge_dis_results[row_index].extend(diss[row_index * batch_len, (row_index + 1) * batch_len])
# _reduce(_ids, _diss, k, reverse)
merge_id_results[row_index], merge_dis_results[row_index] = \ merge_id_results[row_index], merge_dis_results[row_index] = \
self._reduce(merge_id_results[row_index], id_batch, self._reduce(merge_id_results[row_index], id_batch,
merge_dis_results[row_index], dis_batch, merge_dis_results[row_index], dis_batch,
batch_len, batch_len,
reverse) reverse)
# for request_pos, each_request_results in enumerate(
# files_collection.topk_query_result):
# request_results[request_pos].extend(
# each_request_results.query_result_arrays)
# request_results[request_pos] = sorted(
# request_results[request_pos],
# key=lambda x: x.distance,
# reverse=reverse)[:topk]
calc_time = time.time() - calc_time calc_time = time.time() - calc_time
logger.info('Merge takes {}'.format(calc_time)) logger.info('Merge takes {}'.format(calc_time))
# results = sorted(request_results.items())
id_mrege_list = [] id_mrege_list = []
dis_mrege_list = [] dis_mrege_list = []
@ -106,10 +94,6 @@ class ServiceHandler(milvus_pb2_grpc.MilvusServiceServicer):
id_mrege_list.extend(id_results) id_mrege_list.extend(id_results)
dis_mrege_list.extend(dis_results) dis_mrege_list.extend(dis_results)
# for result in results:
# query_result = TopKQueryResult(query_result_arrays=result[1])
# topk_query_result.append(query_result)
return status, id_mrege_list, dis_mrege_list return status, id_mrege_list, dis_mrege_list
def _do_query(self, def _do_query(self,