delete knowhere

Former-commit-id: c04ad4797de102962ee39c6cf364926a120ddcd8
This commit is contained in:
kun yu 2019-08-07 20:01:44 +08:00
parent fe165a6165
commit cc5ab807fc
400 changed files with 0 additions and 125767 deletions

View File

@ -1,67 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
#include "inc/Socket/Common.h"
#include "AggregatorSettings.h"
#include <memory>
#include <vector>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
enum RemoteMachineStatus : uint8_t
{
Disconnected = 0,
Connecting,
Connected
};
struct RemoteMachine
{
RemoteMachine();
std::string m_address;
std::string m_port;
Socket::ConnectionID m_connectionID;
std::atomic<RemoteMachineStatus> m_status;
};
class AggregatorContext
{
public:
AggregatorContext(const std::string& p_filePath);
~AggregatorContext();
bool IsInitialized() const;
const std::vector<std::shared_ptr<RemoteMachine>>& GetRemoteServers() const;
const std::shared_ptr<AggregatorSettings>& GetSettings() const;
private:
std::vector<std::shared_ptr<RemoteMachine>> m_remoteServers;
std::shared_ptr<AggregatorSettings> m_settings;
bool m_initialized;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_

View File

@ -1,53 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#define _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/Packet.h"
#include <memory>
#include <atomic>
namespace SPTAG
{
namespace Aggregator
{
typedef std::shared_ptr<Socket::RemoteSearchResult> AggregatorResult;
class AggregatorExecutionContext
{
public:
AggregatorExecutionContext(std::size_t p_totalServerNumber,
Socket::PacketHeader p_requestHeader);
~AggregatorExecutionContext();
std::size_t GetServerNumber() const;
AggregatorResult& GetResult(std::size_t p_num);
const Socket::PacketHeader& GetRequestHeader() const;
bool IsCompletedAfterFinsh(std::uint32_t p_finishedCount);
private:
std::atomic<std::uint32_t> m_unfinishedCount;
std::vector<AggregatorResult> m_results;
Socket::PacketHeader m_requestHeader;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_

View File

@ -1,88 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
#include "AggregatorContext.h"
#include "AggregatorExecutionContext.h"
#include "inc/Socket/Server.h"
#include "inc/Socket/Client.h"
#include "inc/Socket/ResourceManager.h"
#include <boost/asio.hpp>
#include <memory>
#include <vector>
#include <thread>
#include <condition_variable>
namespace SPTAG
{
namespace Aggregator
{
class AggregatorService
{
public:
AggregatorService();
~AggregatorService();
bool Initialize();
void Run();
private:
void StartClient();
void StartListen();
void WaitForShutdown();
void ConnectToPendingServers();
void AddToPendingServers(std::shared_ptr<RemoteMachine> p_remoteServer);
void SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void AggregateResults(std::shared_ptr<AggregatorExecutionContext> p_exectionContext);
std::shared_ptr<AggregatorContext> GetContext();
private:
typedef std::function<void(Socket::RemoteSearchResult)> AggregatorCallback;
std::shared_ptr<AggregatorContext> m_aggregatorContext;
std::shared_ptr<Socket::Server> m_socketServer;
std::shared_ptr<Socket::Client> m_socketClient;
bool m_initalized;
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
boost::asio::io_context m_ioContext;
boost::asio::signal_set m_shutdownSignals;
std::vector<std::shared_ptr<RemoteMachine>> m_pendingConnectServers;
std::mutex m_pendingConnectServersMutex;
boost::asio::deadline_timer m_pendingConnectServersTimer;
Socket::ResourceManager<AggregatorCallback> m_aggregatorCallbackManager;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_

View File

@ -1,39 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#define _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
#include "../Core/Common.h"
#include <string>
namespace SPTAG
{
namespace Aggregator
{
struct AggregatorSettings
{
AggregatorSettings();
std::string m_listenAddr;
std::string m_listenPort;
std::uint32_t m_searchTimeout;
SizeType m_threadNum;
SizeType m_socketThreadNum;
};
} // namespace Aggregator
} // namespace AnnService
#endif // _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_

View File

@ -1,80 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_CLIENTWRAPPER_H_
#define _SPTAG_CLIENT_CLIENTWRAPPER_H_
#include "inc/Socket/Client.h"
#include "inc/Socket/RemoteSearchQuery.h"
#include "inc/Socket/ResourceManager.h"
#include "Options.h"
#include <string>
#include <vector>
#include <memory>
#include <atomic>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <functional>
namespace SPTAG
{
namespace Client
{
class ClientWrapper
{
public:
typedef std::function<void(Socket::RemoteSearchResult)> Callback;
ClientWrapper(const ClientOptions& p_options);
~ClientWrapper();
void SendQueryAsync(const Socket::RemoteQuery& p_query,
Callback p_callback,
const ClientOptions& p_options);
void WaitAllFinished();
bool IsAvailable() const;
private:
typedef std::pair<Socket::ConnectionID, Socket::ConnectionID> ConnectionPair;
Socket::PacketHandlerMapPtr GetHandlerMap();
void DecreaseUnfnishedJobCount();
const ConnectionPair& GetConnection();
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void HandleDeadConnection(Socket::ConnectionID p_cid);
private:
ClientOptions m_options;
std::unique_ptr<Socket::Client> m_client;
std::atomic<std::uint32_t> m_unfinishedJobCount;
std::atomic_bool m_isWaitingFinish;
std::condition_variable m_waitingQueue;
std::mutex m_waitingMutex;
std::vector<ConnectionPair> m_connections;
std::atomic<std::uint32_t> m_spinCountOfConnection;
Socket::ResourceManager<Callback> m_callbackManager;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_

View File

@ -1,42 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CLIENT_OPTIONS_H_
#define _SPTAG_CLIENT_OPTIONS_H_
#include "inc/Helper/ArgumentsParser.h"
#include <string>
#include <vector>
#include <memory>
namespace SPTAG
{
namespace Client
{
class ClientOptions : public Helper::ArgumentsParser
{
public:
ClientOptions();
virtual ~ClientOptions();
std::string m_serverAddr;
std::string m_serverPort;
// in milliseconds.
std::uint32_t m_searchTimeout;
std::uint32_t m_threadNum;
std::uint32_t m_socketThreadNum;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_CLIENT_OPTIONS_H_

View File

@ -1,112 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_BKT_INDEX_H_
#define _SPTAG_BKT_INDEX_H_
#include "../Common.h"
#include "../VectorIndex.h"
#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/BKTree.h"
#include "inc/Helper/SimpleIniReader.h"
#include "inc/Helper/StringConvert.h"
#include <functional>
#include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG
{
namespace Helper
{
class IniReader;
}
namespace BKT
{
template<typename T>
class Index : public VectorIndex
{
private:
// data points
COMMON::Dataset<T> m_pSamples;
// BKT structures.
COMMON::BKTree m_pTrees;
// Graph structure
COMMON::RelativeNeighborhoodGraph m_pGraph;
std::string m_sBKTFilename;
std::string m_sGraphFilename;
std::string m_sDataPointsFilename;
std::mutex m_dataLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots;
public:
Index()
{
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \
#include "inc/Core/BKT/ParameterDefinitionList.h"
#undef DefineBKTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; }
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; }
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
};
} // namespace BKT
} // namespace SPTAG
#endif // _SPTAG_BKT_INDEX_H_

View File

@ -1,36 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineBKTParameter
// DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber")
DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TpTreeNumber")
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit")
DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
DefineBKTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
DefineBKTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
#endif

View File

@ -1,166 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_CORE_COMMONDEFS_H_
#define _SPTAG_CORE_COMMONDEFS_H_
#include <cstdint>
#include <type_traits>
#include <memory>
#include <string>
#include <limits>
#include <vector>
#include <cmath>
#ifndef _MSC_VER
#include <sys/stat.h>
#include <sys/types.h>
#define FolderSep '/'
#define mkdir(a) mkdir(a, ACCESSPERMS)
inline bool direxists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR);
}
inline bool fileexists(const char* path) {
struct stat info;
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0;
}
template <class T>
inline T min(T a, T b) {
return a < b ? a : b;
}
template <class T>
inline T max(T a, T b) {
return a > b ? a : b;
}
#ifndef _rotl
#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n))))
#endif
#else
#define WIN32_LEAN_AND_MEAN
#include <Windows.h>
#include <Psapi.h>
#define FolderSep '\\'
#define mkdir(a) CreateDirectory(a, NULL)
inline bool direxists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY);
}
inline bool fileexists(const char* path) {
auto dwAttr = GetFileAttributes((LPCSTR)path);
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0;
}
#endif
namespace SPTAG
{
typedef std::uint32_t SizeType;
const float MinDist = (std::numeric_limits<float>::min)();
const float MaxDist = (std::numeric_limits<float>::max)();
const float Epsilon = 0.000000001f;
class MyException : public std::exception
{
private:
std::string Exp;
public:
MyException(std::string e) { Exp = e; }
#ifdef _MSC_VER
const char* what() const { return Exp.c_str(); }
#else
const char* what() const noexcept { return Exp.c_str(); }
#endif
};
// Type of number index.
typedef std::int32_t IndexType;
static_assert(std::is_integral<IndexType>::value, "IndexType must be integral type.");
enum class ErrorCode : std::uint16_t
{
#define DefineErrorCode(Name, Value) Name = Value,
#include "DefinitionList.h"
#undef DefineErrorCode
Undefined
};
static_assert(static_cast<std::uint16_t>(ErrorCode::Undefined) != 0, "Empty ErrorCode!");
enum class DistCalcMethod : std::uint8_t
{
#define DefineDistCalcMethod(Name) Name,
#include "DefinitionList.h"
#undef DefineDistCalcMethod
Undefined
};
static_assert(static_cast<std::uint8_t>(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!");
enum class VectorValueType : std::uint8_t
{
#define DefineVectorValueType(Name, Type) Name,
#include "DefinitionList.h"
#undef DefineVectorValueType
Undefined
};
static_assert(static_cast<std::uint8_t>(VectorValueType::Undefined) != 0, "Empty VectorValueType!");
enum class IndexAlgoType : std::uint8_t
{
#define DefineIndexAlgo(Name) Name,
#include "DefinitionList.h"
#undef DefineIndexAlgo
Undefined
};
static_assert(static_cast<std::uint8_t>(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!");
template<typename T>
constexpr VectorValueType GetEnumValueType()
{
return VectorValueType::Undefined;
}
#define DefineVectorValueType(Name, Type) \
template<> \
constexpr VectorValueType GetEnumValueType<Type>() \
{ \
return VectorValueType::Name; \
} \
#include "DefinitionList.h"
#undef DefineVectorValueType
inline std::size_t GetValueTypeSize(VectorValueType p_valueType)
{
switch (p_valueType)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return sizeof(Type); \
#include "DefinitionList.h"
#undef DefineVectorValueType
default:
break;
}
return 0;
}
} // namespace SPTAG
#endif // _SPTAG_CORE_COMMONDEFS_H_

View File

@ -1,492 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_BKTREE_H_
#define _SPTAG_COMMON_BKTREE_H_
#include <iostream>
#include <stack>
#include <string>
#include <vector>
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "QueryResultSet.h"
#include "WorkSpace.h"
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// node type for storing BKT
struct BKTNode
{
int centerid;
int childStart;
int childEnd;
BKTNode(int cid = -1) : centerid(cid), childStart(-1), childEnd(-1) {}
};
template <typename T>
struct KmeansArgs {
int _K;
int _D;
int _T;
T* centers;
int* counts;
float* newCenters;
int* newCounts;
char* label;
int* clusterIdx;
float* clusterDist;
T* newTCenters;
KmeansArgs(int k, int dim, int datasize, int threadnum) : _K(k), _D(dim), _T(threadnum) {
centers = new T[k * dim];
counts = new int[k];
newCenters = new float[threadnum * k * dim];
newCounts = new int[threadnum * k];
label = new char[datasize];
clusterIdx = new int[threadnum * k];
clusterDist = new float[threadnum * k];
newTCenters = new T[k * dim];
}
~KmeansArgs() {
delete[] centers;
delete[] counts;
delete[] newCenters;
delete[] newCounts;
delete[] label;
delete[] clusterIdx;
delete[] clusterDist;
delete[] newTCenters;
}
inline void ClearCounts() {
memset(newCounts, 0, sizeof(int) * _T * _K);
}
inline void ClearCenters() {
memset(newCenters, 0, sizeof(float) * _T * _K * _D);
}
inline void ClearDists(float dist) {
for (int i = 0; i < _T * _K; i++) {
clusterIdx[i] = -1;
clusterDist[i] = dist;
}
}
void Shuffle(std::vector<int>& indices, int first, int last) {
int* pos = new int[_K];
pos[0] = first;
for (int k = 1; k < _K; k++) pos[k] = pos[k - 1] + newCounts[k - 1];
for (int k = 0; k < _K; k++) {
if (newCounts[k] == 0) continue;
int i = pos[k];
while (newCounts[k] > 0) {
int swapid = pos[(int)(label[i])] + newCounts[(int)(label[i])] - 1;
newCounts[(int)(label[i])]--;
std::swap(indices[i], indices[swapid]);
std::swap(label[i], label[swapid]);
}
while (indices[i] != clusterIdx[k]) i++;
std::swap(indices[i], indices[pos[k] + counts[k] - 1]);
}
delete[] pos;
}
};
class BKTree
{
public:
BKTree(): m_iTreeNumber(1), m_iBKTKmeansK(32), m_iBKTLeafSize(8), m_iSamples(1000) {}
BKTree(BKTree& other): m_iTreeNumber(other.m_iTreeNumber),
m_iBKTKmeansK(other.m_iBKTKmeansK),
m_iBKTLeafSize(other.m_iBKTLeafSize),
m_iSamples(other.m_iSamples) {}
~BKTree() {}
inline const BKTNode& operator[](int index) const { return m_pTreeRoots[index]; }
inline BKTNode& operator[](int index) { return m_pTreeRoots[index]; }
inline int size() const { return (int)m_pTreeRoots.size(); }
inline const std::unordered_map<int, int>& GetSampleMap() const { return m_pSampleCenterMap; }
template <typename T>
void BuildTrees(VectorIndex* index, std::vector<int>* indices = nullptr)
{
struct BKTStackItem {
int index, first, last;
BKTStackItem(int index_, int first_, int last_) : index(index_), first(first_), last(last_) {}
};
std::stack<BKTStackItem> ss;
std::vector<int> localindices;
if (indices == nullptr) {
localindices.resize(index->GetNumSamples());
for (int i = 0; i < index->GetNumSamples(); i++) localindices[i] = i;
}
else {
localindices.assign(indices->begin(), indices->end());
}
KmeansArgs<T> args(m_iBKTKmeansK, index->GetFeatureDim(), (int)localindices.size(), omp_get_num_threads());
m_pSampleCenterMap.clear();
for (char i = 0; i < m_iTreeNumber; i++)
{
std::random_shuffle(localindices.begin(), localindices.end());
m_pTreeStart.push_back((int)m_pTreeRoots.size());
m_pTreeRoots.push_back(BKTNode((int)localindices.size()));
std::cout << "Start to build BKTree " << i + 1 << std::endl;
ss.push(BKTStackItem(m_pTreeStart[i], 0, (int)localindices.size()));
while (!ss.empty()) {
BKTStackItem item = ss.top(); ss.pop();
int newBKTid = (int)m_pTreeRoots.size();
m_pTreeRoots[item.index].childStart = newBKTid;
if (item.last - item.first <= m_iBKTLeafSize) {
for (int j = item.first; j < item.last; j++) {
m_pTreeRoots.push_back(BKTNode(localindices[j]));
}
}
else { // clustering the data into BKTKmeansK clusters
int numClusters = KmeansClustering(index, localindices, item.first, item.last, args);
if (numClusters <= 1) {
int end = min(item.last + 1, (int)localindices.size());
std::sort(localindices.begin() + item.first, localindices.begin() + end);
m_pTreeRoots[item.index].centerid = localindices[item.first];
m_pTreeRoots[item.index].childStart = -m_pTreeRoots[item.index].childStart;
for (int j = item.first + 1; j < end; j++) {
m_pTreeRoots.push_back(BKTNode(localindices[j]));
m_pSampleCenterMap[localindices[j]] = m_pTreeRoots[item.index].centerid;
}
m_pSampleCenterMap[-1 - m_pTreeRoots[item.index].centerid] = item.index;
}
else {
for (int k = 0; k < m_iBKTKmeansK; k++) {
if (args.counts[k] == 0) continue;
m_pTreeRoots.push_back(BKTNode(localindices[item.first + args.counts[k] - 1]));
if (args.counts[k] > 1) ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1));
item.first += args.counts[k];
}
}
}
m_pTreeRoots[item.index].childEnd = (int)m_pTreeRoots.size();
}
std::cout << i + 1 << " BKTree built, " << m_pTreeRoots.size() - m_pTreeStart[i] << " " << localindices.size() << std::endl;
}
}
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
{
int treeNodeSize = (int)m_pTreeRoots.size();
size_t size = sizeof(int) +
sizeof(int) * m_iTreeNumber +
sizeof(int) +
sizeof(BKTNode) * treeNodeSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iTreeNumber;
ptr += sizeof(int);
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
ptr += sizeof(int) * m_iTreeNumber;
*(int*)ptr = treeNodeSize;
ptr += sizeof(int);
memcpy(ptr, m_pTreeRoots.data(), sizeof(BKTNode) * treeNodeSize);
*pKDTMemFile = mem;
len = size;
return true;
}
bool SaveTrees(std::string sTreeFileName) const
{
std::cout << "Save BKT to " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize = (int)m_pTreeRoots.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
fwrite(m_pTreeRoots.data(), sizeof(BKTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Save BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
bool LoadTrees(char* pBKTMemFile)
{
m_iTreeNumber = *((int*)pBKTMemFile);
pBKTMemFile += sizeof(int);
m_pTreeStart.resize(m_iTreeNumber);
memcpy(m_pTreeStart.data(), pBKTMemFile, sizeof(int) * m_iTreeNumber);
pBKTMemFile += sizeof(int)*m_iTreeNumber;
int treeNodeSize = *((int*)pBKTMemFile);
pBKTMemFile += sizeof(int);
m_pTreeRoots.resize(treeNodeSize);
memcpy(m_pTreeRoots.data(), pBKTMemFile, sizeof(BKTNode) * treeNodeSize);
return true;
}
bool LoadTrees(std::string sTreeFileName)
{
std::cout << "Load BKT From " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iTreeNumber, sizeof(int), 1, fp);
m_pTreeStart.resize(m_iTreeNumber);
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp);
m_pTreeRoots.resize(treeNodeSize);
fread(m_pTreeRoots.data(), sizeof(BKTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Load BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
template <typename T>
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const
{
for (char i = 0; i < m_iTreeNumber; i++) {
const BKTNode& node = m_pTreeRoots[m_pTreeStart[i]];
if (node.childStart < 0) {
p_space.m_SPTQueue.insert(COMMON::HeapCell(m_pTreeStart[i], p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(node.centerid))));
}
else {
for (int begin = node.childStart; begin < node.childEnd; begin++) {
int index = m_pTreeRoots[begin].centerid;
p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index))));
}
}
}
}
template <typename T>
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
COMMON::WorkSpace &p_space, const int p_limits) const
{
do
{
COMMON::HeapCell bcell = p_space.m_SPTQueue.pop();
const BKTNode& tnode = m_pTreeRoots[bcell.node];
if (tnode.childStart < 0) {
if (!p_space.CheckAndSet(tnode.centerid)) {
p_space.m_iNumberOfCheckedLeaves++;
p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
}
if (p_space.m_iNumberOfCheckedLeaves >= p_limits) break;
}
else {
if (!p_space.CheckAndSet(tnode.centerid)) {
p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
}
for (int begin = tnode.childStart; begin < tnode.childEnd; begin++) {
int index = m_pTreeRoots[begin].centerid;
p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index))));
}
}
} while (!p_space.m_SPTQueue.empty());
}
private:
template <typename T>
float KmeansAssign(VectorIndex* p_index,
std::vector<int>& indices,
const int first, const int last, KmeansArgs<T>& args, const bool updateCenters) const {
float currDist = 0;
int threads = omp_get_num_threads();
float lambda = (updateCenters) ? COMMON::Utils::GetBase<T>() * COMMON::Utils::GetBase<T>() / (100.0f * (last - first)) : 0.0f;
int subsize = (last - first - 1) / threads + 1;
#pragma omp parallel for
for (int tid = 0; tid < threads; tid++)
{
int istart = first + tid * subsize;
int iend = min(first + (tid + 1) * subsize, last);
int *inewCounts = args.newCounts + tid * m_iBKTKmeansK;
float *inewCenters = args.newCenters + tid * m_iBKTKmeansK * p_index->GetFeatureDim();
int * iclusterIdx = args.clusterIdx + tid * m_iBKTKmeansK;
float * iclusterDist = args.clusterDist + tid * m_iBKTKmeansK;
float idist = 0;
for (int i = istart; i < iend; i++) {
int clusterid = 0;
float smallestDist = MaxDist;
for (int k = 0; k < m_iBKTKmeansK; k++) {
float dist = p_index->ComputeDistance(p_index->GetSample(indices[i]), (const void*)(args.centers + k*p_index->GetFeatureDim())) + lambda*args.counts[k];
if (dist > -MaxDist && dist < smallestDist) {
clusterid = k; smallestDist = dist;
}
}
args.label[i] = clusterid;
inewCounts[clusterid]++;
idist += smallestDist;
if (updateCenters) {
const T* v = (const T*)p_index->GetSample(indices[i]);
float* center = inewCenters + clusterid*p_index->GetFeatureDim();
for (int j = 0; j < p_index->GetFeatureDim(); j++) center[j] += v[j];
if (smallestDist > iclusterDist[clusterid]) {
iclusterDist[clusterid] = smallestDist;
iclusterIdx[clusterid] = indices[i];
}
}
else {
if (smallestDist <= iclusterDist[clusterid]) {
iclusterDist[clusterid] = smallestDist;
iclusterIdx[clusterid] = indices[i];
}
}
}
COMMON::Utils::atomic_float_add(&currDist, idist);
}
for (int i = 1; i < threads; i++) {
for (int k = 0; k < m_iBKTKmeansK; k++)
args.newCounts[k] += args.newCounts[i*m_iBKTKmeansK + k];
}
if (updateCenters) {
for (int i = 1; i < threads; i++) {
float* currCenter = args.newCenters + i*m_iBKTKmeansK*p_index->GetFeatureDim();
for (int j = 0; j < m_iBKTKmeansK * p_index->GetFeatureDim(); j++) args.newCenters[j] += currCenter[j];
}
int maxcluster = 0;
for (int k = 1; k < m_iBKTKmeansK; k++) if (args.newCounts[maxcluster] < args.newCounts[k]) maxcluster = k;
int maxid = maxcluster;
for (int tid = 1; tid < threads; tid++) {
if (args.clusterDist[maxid] < args.clusterDist[tid * m_iBKTKmeansK + maxcluster]) maxid = tid * m_iBKTKmeansK + maxcluster;
}
if (args.clusterIdx[maxid] < 0 || args.clusterIdx[maxid] >= p_index->GetNumSamples())
std::cout << "first:" << first << " last:" << last << " maxcluster:" << maxcluster << "(" << args.newCounts[maxcluster] << ") Error maxid:" << maxid << " dist:" << args.clusterDist[maxid] << std::endl;
maxid = args.clusterIdx[maxid];
for (int k = 0; k < m_iBKTKmeansK; k++) {
T* TCenter = args.newTCenters + k * p_index->GetFeatureDim();
if (args.newCounts[k] == 0) {
//int nextid = Utils::rand_int(last, first);
//while (args.label[nextid] != maxcluster) nextid = Utils::rand_int(last, first);
int nextid = maxid;
std::memcpy(TCenter, p_index->GetSample(nextid), sizeof(T)*p_index->GetFeatureDim());
}
else {
float* currCenters = args.newCenters + k * p_index->GetFeatureDim();
for (int j = 0; j < p_index->GetFeatureDim(); j++) currCenters[j] /= args.newCounts[k];
if (p_index->GetDistCalcMethod() == DistCalcMethod::Cosine) {
COMMON::Utils::Normalize(currCenters, p_index->GetFeatureDim(), COMMON::Utils::GetBase<T>());
}
for (int j = 0; j < p_index->GetFeatureDim(); j++) TCenter[j] = (T)(currCenters[j]);
}
}
}
else {
for (int i = 1; i < threads; i++) {
for (int k = 0; k < m_iBKTKmeansK; k++) {
if (args.clusterIdx[i*m_iBKTKmeansK + k] != -1 && args.clusterDist[i*m_iBKTKmeansK + k] <= args.clusterDist[k]) {
args.clusterDist[k] = args.clusterDist[i*m_iBKTKmeansK + k];
args.clusterIdx[k] = args.clusterIdx[i*m_iBKTKmeansK + k];
}
}
}
}
return currDist;
}
template <typename T>
int KmeansClustering(VectorIndex* p_index,
std::vector<int>& indices, const int first, const int last, KmeansArgs<T>& args) const {
int iterLimit = 100;
int batchEnd = min(first + m_iSamples, last);
float currDiff, currDist, minClusterDist = MaxDist;
for (int numKmeans = 0; numKmeans < 3; numKmeans++) {
for (int k = 0; k < m_iBKTKmeansK; k++) {
int randid = COMMON::Utils::rand_int(last, first);
std::memcpy(args.centers + k*p_index->GetFeatureDim(), p_index->GetSample(indices[randid]), sizeof(T)*p_index->GetFeatureDim());
}
args.ClearCounts();
currDist = KmeansAssign(p_index, indices, first, batchEnd, args, false);
if (currDist < minClusterDist) {
minClusterDist = currDist;
memcpy(args.newTCenters, args.centers, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim());
memcpy(args.counts, args.newCounts, sizeof(int) * m_iBKTKmeansK);
}
}
minClusterDist = MaxDist;
int noImprovement = 0;
for (int iter = 0; iter < iterLimit; iter++) {
std::memcpy(args.centers, args.newTCenters, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim());
std::random_shuffle(indices.begin() + first, indices.begin() + last);
args.ClearCenters();
args.ClearCounts();
args.ClearDists(-MaxDist);
currDist = KmeansAssign(p_index, indices, first, batchEnd, args, true);
memcpy(args.counts, args.newCounts, sizeof(int)*m_iBKTKmeansK);
currDiff = 0;
for (int k = 0; k < m_iBKTKmeansK; k++) {
currDiff += p_index->ComputeDistance((const void*)(args.centers + k*p_index->GetFeatureDim()), (const void*)(args.newTCenters + k*p_index->GetFeatureDim()));
}
if (currDist < minClusterDist) {
noImprovement = 0;
minClusterDist = currDist;
}
else {
noImprovement++;
}
if (currDiff < 1e-3 || noImprovement >= 5) break;
}
args.ClearCounts();
args.ClearDists(MaxDist);
currDist = KmeansAssign(p_index, indices, first, last, args, false);
memcpy(args.counts, args.newCounts, sizeof(int)*m_iBKTKmeansK);
int numClusters = 0;
for (int i = 0; i < m_iBKTKmeansK; i++) if (args.counts[i] > 0) numClusters++;
if (numClusters <= 1) {
//if (last - first > 1) std::cout << "large cluster:" << last - first << " dist:" << currDist << std::endl;
return numClusters;
}
args.Shuffle(indices, first, last);
return numClusters;
}
private:
std::vector<int> m_pTreeStart;
std::vector<BKTNode> m_pTreeRoots;
std::unordered_map<int, int> m_pSampleCenterMap;
public:
int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples;
};
}
}
#endif

View File

@ -1,178 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_COMMONUTILS_H_
#define _SPTAG_COMMON_COMMONUTILS_H_
#include "../Common.h"
#include <unordered_map>
#include <fstream>
#include <iostream>
#include <exception>
#include <algorithm>
#include <time.h>
#include <omp.h>
#include <string.h>
#define PREFETCH
#ifndef _MSC_VER
#include <stdio.h>
#include <unistd.h>
#include <sys/resource.h>
#include <cstring>
#define InterlockedCompareExchange(a,b,c) __sync_val_compare_and_swap(a, c, b)
#define Sleep(a) usleep(a * 1000)
#define strtok_s(a, b, c) strtok_r(a, b, c)
#endif
namespace SPTAG
{
namespace COMMON
{
class Utils {
public:
static int rand_int(int high = RAND_MAX, int low = 0) // Generates a random int value.
{
return low + (int)(float(high - low)*(std::rand() / (RAND_MAX + 1.0)));
}
static inline float atomic_float_add(volatile float* ptr, const float operand)
{
union {
volatile long iOld;
float fOld;
};
union {
long iNew;
float fNew;
};
while (true) {
iOld = *(volatile long *)ptr;
fNew = fOld + operand;
if (InterlockedCompareExchange((long *)ptr, iNew, iOld) == iOld) {
return fNew;
}
}
}
static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, int& NumDim) {
char* current;
char* context = NULL;
int i = 0;
double sum = 0;
arr.clear();
current = strtok_s(cstr, sep, &context);
while (current != NULL && (i < NumDim || NumDim < 0)) {
try {
float val = (float)atof(current);
arr.push_back(val);
}
catch (std::exception e) {
std::cout << "Exception:" << e.what() << std::endl;
return -2;
}
sum += arr[i] * arr[i];
current = strtok_s(NULL, sep, &context);
i++;
}
if (NumDim < 0) NumDim = i;
if (i < NumDim) return -2;
return std::sqrt(sum);
}
template <typename T>
static void Normalize(T* arr, int col, int base) {
double vecLen = 0;
for (int j = 0; j < col; j++) {
double val = arr[j];
vecLen += val * val;
}
vecLen = std::sqrt(vecLen);
if (vecLen < 1e-6) {
T val = (T)(1.0 / std::sqrt((double)col) * base);
for (int j = 0; j < col; j++) arr[j] = val;
}
else {
for (int j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base);
}
}
static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, int& D, int base, DistCalcMethod distCalcMethod) {
size_t index;
double vecLen;
if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) {
std::cout << "Parse vector error: " + currentLine << std::endl;
//throw MyException("Error in parsing data " + currentLine);
return -1;
}
if (distCalcMethod == DistCalcMethod::Cosine) {
Normalize(arr.data(), D, base);
}
return index;
}
template <typename T>
static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, int& NumQuery, int& NumDim, DistCalcMethod distCalcMethod, int base) {
std::string currentLine;
std::vector<float> arr;
int i = 0;
size_t index;
while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) {
std::getline(inStream, currentLine);
if (currentLine.length() <= 1 || (index = ProcessLine(currentLine, arr, NumDim, base, distCalcMethod)) < 0) {
continue;
}
qString.push_back(currentLine.substr(0, index));
if (Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0));
for (int j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j];
i++;
}
NumQuery = i;
std::cout << "Load data: (" << NumQuery << ", " << NumDim << ")" << std::endl;
}
template<typename T>
static inline int GetBase() {
if (GetEnumValueType<T>() != VectorValueType::Float) {
return (int)(std::numeric_limits<T>::max)();
}
return 1;
}
static inline void AddNeighbor(int idx, float dist, int *neighbors, float *dists, int size)
{
size--;
if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size]))
{
int nb;
for (nb = 0; nb <= size && neighbors[nb] != idx; nb++);
if (nb > size)
{
nb = size;
while (nb > 0 && (dist < dists[nb - 1] || (dist == dists[nb - 1] && idx < neighbors[nb - 1])))
{
dists[nb] = dists[nb - 1];
neighbors[nb] = neighbors[nb - 1];
nb--;
}
dists[nb] = dist;
neighbors[nb] = idx;
}
}
}
};
}
}
#endif // _SPTAG_COMMON_COMMONUTILS_H_

View File

@ -1,290 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATAUTILS_H_
#define _SPTAG_COMMON_DATAUTILS_H_
#include <sys/stat.h>
#include <atomic>
#include "CommonUtils.h"
#include "../../Helper/CommonHelper.h"
namespace SPTAG
{
namespace COMMON
{
const int bufsize = 1024 * 1024 * 1024;
class DataUtils {
public:
template <typename T>
static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
std::ifstream inputStream(filename);
if (!inputStream.is_open()) {
std::cerr << "unable to open file " + filename << std::endl;
throw MyException("unable to open file " + filename);
exit(1);
}
std::ofstream outputStream, metaStream_out, metaStream_index;
outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
throw MyException("unable to open output files");
exit(1);
}
std::vector<float> arr;
std::vector<T> sample;
int base = 1;
if (distCalcMethod == DistCalcMethod::Cosine) {
base = Utils::GetBase<T>();
}
std::uint64_t writepos = 0;
int sampleSize = 0;
std::uint64_t totalread = 0;
std::streamoff startpos = id * blocksize;
#ifndef _MSC_VER
int enter_size = 1;
#else
int enter_size = 1;
#endif
std::string currentLine;
size_t index;
inputStream.seekg(startpos, std::ifstream::beg);
if (id != 0) {
std::getline(inputStream, currentLine);
totalread += currentLine.length() + enter_size;
}
std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
while (!inputStream.eof() && totalread <= blocksize) {
std::getline(inputStream, currentLine);
if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
totalread += currentLine.length() + enter_size;
continue;
}
sample.resize(D);
for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
outputStream.write((char *)(sample.data()), sizeof(T)*D);
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_out.write(currentLine.c_str(), index);
writepos += index;
sampleSize += 1;
totalread += currentLine.length() + enter_size;
}
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
metaStream_index.write((char *)&sampleSize, sizeof(int));
inputStream.close();
outputStream.close();
metaStream_out.close();
metaStream_index.close();
numSamples.fetch_add(sampleSize);
std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
}
static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
std::atomic_int& numSamples, int D) {
std::ifstream inputStream;
std::ofstream outputStream;
char * buf = new char[bufsize];
std::uint64_t * offsets;
int partSamples;
int metaSamples = 0;
std::uint64_t lastoff = 0;
outputStream.open(outfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
outputStream.write((char *)&D, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
outputStream.open(outmetafile, std::ofstream::binary);
for (int i = 0; i < threadbase; i++) {
std::string file = outmetafile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
while (!inputStream.eof()) {
inputStream.read(buf, bufsize);
outputStream.write(buf, inputStream.gcount());
}
inputStream.close();
remove(file.c_str());
}
outputStream.close();
delete[] buf;
outputStream.open(outmetaindexfile, std::ofstream::binary);
outputStream.write((char *)&numSamples, sizeof(int));
for (int i = 0; i < threadbase; i++) {
std::string file = outmetaindexfile + std::to_string(i);
inputStream.open(file, std::ifstream::binary);
inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
inputStream.read((char *)&partSamples, sizeof(int));
offsets = new std::uint64_t[partSamples + 1];
inputStream.seekg(0, inputStream.beg);
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
inputStream.close();
remove(file.c_str());
for (int j = 0; j < partSamples + 1; j++)
offsets[j] += lastoff;
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
lastoff = offsets[partSamples];
metaSamples += partSamples;
delete[] offsets;
}
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
}
static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
std::ifstream inputStream1, inputStream2;
std::ofstream outputStream;
char * buf = new char[bufsize];
int R1, R2, C1, C2;
#define MergeVector(inputStream, vectorFile, R, C) \
inputStream.open(vectorFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
return false; \
} \
inputStream.read((char *)&(R), sizeof(int)); \
inputStream.read((char *)&(C), sizeof(int)); \
MergeVector(inputStream1, p_vectorfile1, R1, C1)
MergeVector(inputStream2, p_vectorfile2, R2, C2)
#undef MergeVector
if (C1 != C2) {
inputStream1.close(); inputStream2.close();
std::cout << "Vector dimensions are not the same!" << std::endl;
return false;
}
R1 += R2;
outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
outputStream.write((char *)&C1, sizeof(int));
while (!inputStream1.eof()) {
inputStream1.read(buf, bufsize);
outputStream.write(buf, inputStream1.gcount());
}
while (!inputStream2.eof()) {
inputStream2.read(buf, bufsize);
outputStream.write(buf, inputStream2.gcount());
}
inputStream1.close(); inputStream2.close();
outputStream.close();
if (p_metafile1 != "" && p_metafile2 != "") {
outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary);
#define MergeMeta(inputStream, metaFile) \
inputStream.open(metaFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \
return false; \
} \
while (!inputStream.eof()) { \
inputStream.read(buf, bufsize); \
outputStream.write(buf, inputStream.gcount()); \
} \
inputStream.close(); \
MergeMeta(inputStream1, p_metafile1)
MergeMeta(inputStream2, p_metafile2)
#undef MergeMeta
outputStream.close();
delete[] buf;
std::uint64_t * offsets;
int partSamples;
std::uint64_t lastoff = 0;
outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
outputStream.write((char *)&R1, sizeof(int));
#define MergeMetaIndex(inputStream, metaIndexFile) \
inputStream.open(metaIndexFile, std::ifstream::binary); \
if (!inputStream.is_open()) { \
std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
return false; \
} \
inputStream.read((char *)&partSamples, sizeof(int)); \
offsets = new std::uint64_t[partSamples + 1]; \
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
inputStream.close(); \
for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \
lastoff = offsets[partSamples]; \
delete[] offsets; \
MergeMetaIndex(inputStream1, p_metaindexfile1)
MergeMetaIndex(inputStream2, p_metaindexfile2)
#undef MergeMetaIndex
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
outputStream.close();
rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str());
rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str());
}
rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str());
std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
return true;
}
template <typename T>
static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
int threadnum, DistCalcMethod distCalcMethod) {
omp_set_num_threads(threadnum);
std::atomic_int numSamples = { 0 };
int D = -1;
int threadbase = 0;
std::vector<std::string> inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
for (std::string inputFileName : inputFileNames)
{
#ifndef _MSC_VER
struct stat stat_buf;
stat(inputFileName.c_str(), &stat_buf);
#else
struct _stat64 stat_buf;
int res = _stat64(inputFileName.c_str(), &stat_buf);
#endif
std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
#pragma omp parallel for
for (int i = 0; i < threadnum; i++) {
ProcessTSVData<T>(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
}
threadbase += threadnum;
}
MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
}
};
}
}
#endif // _SPTAG_COMMON_DATAUTILS_H_

View File

@ -1,216 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DATASET_H_
#define _SPTAG_COMMON_DATASET_H_
#include <fstream>
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
#include <malloc.h>
#else
#include <mm_malloc.h>
#endif // defined(__GNUC__)
#define ALIGN 32
#define aligned_malloc(a, b) _mm_malloc(a, b)
#define aligned_free(a) _mm_free(a)
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// structure to save Data and Graph
template <typename T>
class Dataset
{
private:
int rows;
int cols;
bool ownData = false;
T* data = nullptr;
std::vector<T> dataIncremental;
public:
Dataset(): rows(0), cols(1) {}
Dataset(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
Initialize(rows_, cols_, data_, transferOnwership_);
}
~Dataset()
{
if (ownData) aligned_free(data);
}
void Initialize(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
{
rows = rows_;
cols = cols_;
data = data_;
if (data_ == nullptr || !transferOnwership_)
{
ownData = true;
data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
if (data_ != nullptr) memcpy(data, data_, rows * cols * sizeof(T));
else std::memset(data, -1, rows * cols * sizeof(T));
}
}
void SetR(int R_)
{
if (R_ >= rows)
dataIncremental.resize((R_ - rows) * cols);
else
{
rows = R_;
dataIncremental.clear();
}
}
inline int R() const { return (int)(rows + dataIncremental.size() / cols); }
inline int C() const { return cols; }
T* operator[](int index)
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
const T* operator[](int index) const
{
if (index >= rows) {
return dataIncremental.data() + (size_t)(index - rows)*cols;
}
return data + (size_t)index*cols;
}
void AddBatch(const T* pData, int num)
{
dataIncremental.insert(dataIncremental.end(), pData, pData + num*cols);
}
void AddBatch(int num)
{
dataIncremental.insert(dataIncremental.end(), (size_t)num*cols, T(-1));
}
bool Save(std::string sDataPointsFileName)
{
std::cout << "Save Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int CR = R();
fwrite(&CR, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
T* ptr = data;
int toWrite = rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
ptr = dataIncremental.data();
toWrite = CR - rows;
while (toWrite > 0)
{
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
ptr += write * cols;
toWrite -= (int)write;
}
fclose(fp);
std::cout << "Save Data (" << CR << ", " << cols << ") Finish!" << std::endl;
return true;
}
bool Save(void **pDataPointsMemFile, int64_t &len)
{
size_t size = sizeof(int) + sizeof(int) + sizeof(T) * R() *cols;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
int CR = R();
auto header = (int*)mem;
header[0] = CR;
header[1] = cols;
auto body = &mem[8];
memcpy(body, data, sizeof(T) * cols * rows);
body += sizeof(T) * cols * rows;
memcpy(body, dataIncremental.data(), sizeof(T) * cols * (CR - rows));
body += sizeof(T) * cols * (CR - rows);
*pDataPointsMemFile = mem;
len = size;
return true;
}
bool Load(std::string sDataPointsFileName)
{
std::cout << "Load Data From " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
if (fp == NULL) return false;
int R, C;
fread(&R, sizeof(int), 1, fp);
fread(&C, sizeof(int), 1, fp);
Initialize(R, C);
T* ptr = data;
while (R > 0) {
size_t read = fread(ptr, sizeof(T) * C, R, fp);
ptr += read * C;
R -= (int)read;
}
fclose(fp);
std::cout << "Load Data (" << rows << ", " << cols << ") Finish!" << std::endl;
return true;
}
// Functions for loading models from memory mapped files
bool Load(char* pDataPointsMemFile)
{
int R, C;
R = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
C = *((int*)pDataPointsMemFile);
pDataPointsMemFile += sizeof(int);
Initialize(R, C, (T*)pDataPointsMemFile);
return true;
}
bool Refine(const std::vector<int>& indices, std::string sDataPointsFileName)
{
std::cout << "Save Refine Data To " << sDataPointsFileName << std::endl;
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
if (fp == NULL) return false;
int R = (int)(indices.size());
fwrite(&R, sizeof(int), 1, fp);
fwrite(&cols, sizeof(int), 1, fp);
// write point one by one in case for cache miss
for (int i = 0; i < R; i++) {
if (indices[i] < rows)
fwrite(data + (size_t)indices[i] * cols, sizeof(T) * cols, 1, fp);
else
fwrite(dataIncremental.data() + (size_t)(indices[i] - rows) * cols, sizeof(T) * cols, 1, fp);
}
fclose(fp);
std::cout << "Save Refine Data (" << R << ", " << cols << ") Finish!" << std::endl;
return true;
}
};
}
}
#endif // _SPTAG_COMMON_DATASET_H_

View File

@ -1,610 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_DISTANCEUTILS_H_
#define _SPTAG_COMMON_DISTANCEUTILS_H_
#include <immintrin.h>
#include <functional>
#include "CommonUtils.h"
#define SSE
#ifndef _MSC_VER
#define DIFF128 diff128
#define DIFF256 diff256
#else
#define DIFF128 diff128.m128_f32
#define DIFF256 diff256.m256_f32
#endif
namespace SPTAG
{
namespace COMMON
{
class DistanceUtils
{
public:
static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
}
static inline __m128 _mm_sqdf_epi8(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi8(X, zero);
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
__m128i dlo = _mm_sub_epi16(xlo, ylo);
__m128i dhi = _mm_sub_epi16(xhi, yhi);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi)));
}
static inline __m128 _mm_mul_epu8(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i xlo = _mm_unpacklo_epi8(X, zero);
__m128i xhi = _mm_unpackhi_epi8(X, zero);
__m128i ylo = _mm_unpacklo_epi8(Y, zero);
__m128i yhi = _mm_unpackhi_epi8(Y, zero);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
}
static inline __m128 _mm_sqdf_epu8(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i xlo = _mm_unpacklo_epi8(X, zero);
__m128i xhi = _mm_unpackhi_epi8(X, zero);
__m128i ylo = _mm_unpacklo_epi8(Y, zero);
__m128i yhi = _mm_unpackhi_epi8(Y, zero);
__m128i dlo = _mm_sub_epi16(xlo, ylo);
__m128i dhi = _mm_sub_epi16(xhi, yhi);
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi)));
}
static inline __m128 _mm_mul_epi16(__m128i X, __m128i Y)
{
return _mm_cvtepi32_ps(_mm_madd_epi16(X, Y));
}
static inline __m128 _mm_sqdf_epi16(__m128i X, __m128i Y)
{
__m128i zero = _mm_setzero_si128();
__m128i sign_x = _mm_cmplt_epi16(X, zero);
__m128i sign_y = _mm_cmplt_epi16(Y, zero);
__m128i xlo = _mm_unpacklo_epi16(X, sign_x);
__m128i xhi = _mm_unpackhi_epi16(X, sign_x);
__m128i ylo = _mm_unpacklo_epi16(Y, sign_y);
__m128i yhi = _mm_unpackhi_epi16(Y, sign_y);
__m128 dlo = _mm_cvtepi32_ps(_mm_sub_epi32(xlo, ylo));
__m128 dhi = _mm_cvtepi32_ps(_mm_sub_epi32(xhi, yhi));
return _mm_add_ps(_mm_mul_ps(dlo, dlo), _mm_mul_ps(dhi, dhi));
}
static inline __m128 _mm_sqdf_ps(__m128 X, __m128 Y)
{
__m128 d = _mm_sub_ps(X, Y);
return _mm_mul_ps(d, d);
}
#if defined(AVX)
static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
}
static inline __m256 _mm256_sqdf_epi8(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
__m256i dlo = _mm256_sub_epi16(xlo, ylo);
__m256i dhi = _mm256_sub_epi16(xhi, yhi);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi)));
}
static inline __m256 _mm256_mul_epu8(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i xlo = _mm256_unpacklo_epi8(X, zero);
__m256i xhi = _mm256_unpackhi_epi8(X, zero);
__m256i ylo = _mm256_unpacklo_epi8(Y, zero);
__m256i yhi = _mm256_unpackhi_epi8(Y, zero);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
}
static inline __m256 _mm256_sqdf_epu8(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i xlo = _mm256_unpacklo_epi8(X, zero);
__m256i xhi = _mm256_unpackhi_epi8(X, zero);
__m256i ylo = _mm256_unpacklo_epi8(Y, zero);
__m256i yhi = _mm256_unpackhi_epi8(Y, zero);
__m256i dlo = _mm256_sub_epi16(xlo, ylo);
__m256i dhi = _mm256_sub_epi16(xhi, yhi);
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi)));
}
static inline __m256 _mm256_mul_epi16(__m256i X, __m256i Y)
{
return _mm256_cvtepi32_ps(_mm256_madd_epi16(X, Y));
}
static inline __m256 _mm256_sqdf_epi16(__m256i X, __m256i Y)
{
__m256i zero = _mm256_setzero_si256();
__m256i sign_x = _mm256_cmpgt_epi16(zero, X);
__m256i sign_y = _mm256_cmpgt_epi16(zero, Y);
__m256i xlo = _mm256_unpacklo_epi16(X, sign_x);
__m256i xhi = _mm256_unpackhi_epi16(X, sign_x);
__m256i ylo = _mm256_unpacklo_epi16(Y, sign_y);
__m256i yhi = _mm256_unpackhi_epi16(Y, sign_y);
__m256 dlo = _mm256_cvtepi32_ps(_mm256_sub_epi32(xlo, ylo));
__m256 dhi = _mm256_cvtepi32_ps(_mm256_sub_epi32(xhi, yhi));
return _mm256_add_ps(_mm256_mul_ps(dlo, dlo), _mm256_mul_ps(dhi, dhi));
}
static inline __m256 _mm256_sqdf_ps(__m256 X, __m256 Y)
{
__m256 d = _mm256_sub_ps(X, Y);
return _mm256_mul_ps(d, d);
}
#endif
/*
template<typename T>
static float ComputeL2Distance(const T *pX, const T *pY, int length)
{
float diff = 0;
const T* pEnd1 = pX + length;
while (pX < pEnd1) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
return diff;
}
*/
#define REPEAT(type, ctype, delta, load, exec, acc, result) \
{ \
type c1 = load((ctype *)(pX)); \
type c2 = load((ctype *)(pY)); \
pX += delta; pY += delta; \
result = acc(result, exec(c1, c2)); \
} \
static float ComputeL2Distance(const std::int8_t *pX, const std::int8_t *pY, int length)
{
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int8_t* pEnd4 = pX + ((length >> 2) << 2);
const std::int8_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
}
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
while (pX < pEnd1) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
return diff;
}
static float ComputeL2Distance(const std::uint8_t *pX, const std::uint8_t *pY, int length)
{
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2);
const std::uint8_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
}
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
while (pX < pEnd1) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
return diff;
}
static float ComputeL2Distance(const std::int16_t *pX, const std::int16_t *pY, int length)
{
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
const std::int16_t* pEnd4 = pX + ((length >> 2) << 2);
const std::int16_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
}
while (pX < pEnd8) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd16) {
REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd8) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
while (pX < pEnd1) {
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
}
return diff;
}
static float ComputeL2Distance(const float *pX, const float *pY, int length)
{
const float* pEnd16 = pX + ((length >> 4) << 4);
const float* pEnd4 = pX + ((length >> 2) << 2);
const float* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd16)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
}
while (pX < pEnd4)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd16)
{
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256)
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd4)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
while (pX < pEnd4) {
float c1 = (*pX++) - (*pY++); diff += c1 * c1;
c1 = (*pX++) - (*pY++); diff += c1 * c1;
c1 = (*pX++) - (*pY++); diff += c1 * c1;
c1 = (*pX++) - (*pY++); diff += c1 * c1;
}
#endif
while (pX < pEnd1) {
float c1 = (*pX++) - (*pY++); diff += c1 * c1;
}
return diff;
}
/*
template<typename T>
static float ComputeCosineDistance(const T *pX, const T *pY, int length) {
float diff = 0;
const T* pEnd1 = pX + length;
while (pX < pEnd1) diff += (*pX++) * (*pY++);
return 1 - diff;
}
*/
static float ComputeCosineDistance(const std::int8_t *pX, const std::int8_t *pY, int length) {
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int8_t* pEnd4 = pX + ((length >> 2) << 2);
const std::int8_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
}
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4)
{
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
}
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
return 16129 - diff;
}
static float ComputeCosineDistance(const std::uint8_t *pX, const std::uint8_t *pY, int length) {
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2);
const std::uint8_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
}
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd32) {
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4)
{
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
}
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
return 65025 - diff;
}
static float ComputeCosineDistance(const std::int16_t *pX, const std::int16_t *pY, int length) {
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
const std::int16_t* pEnd4 = pX + ((length >> 2) << 2);
const std::int16_t* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd16) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
}
while (pX < pEnd8) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd16) {
REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd8) {
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
#endif
while (pX < pEnd4)
{
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
}
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
return 1073676289 - diff;
}
static float ComputeCosineDistance(const float *pX, const float *pY, int length) {
const float* pEnd16 = pX + ((length >> 4) << 4);
const float* pEnd4 = pX + ((length >> 2) << 2);
const float* pEnd1 = pX + length;
#if defined(SSE)
__m128 diff128 = _mm_setzero_ps();
while (pX < pEnd16)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
}
while (pX < pEnd4)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#elif defined(AVX)
__m256 diff256 = _mm256_setzero_ps();
while (pX < pEnd16)
{
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256)
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256)
}
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
while (pX < pEnd4)
{
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
}
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
#else
float diff = 0;
while (pX < pEnd4)
{
float c1 = (*pX++) * (*pY++); diff += c1;
c1 = (*pX++) * (*pY++); diff += c1;
c1 = (*pX++) * (*pY++); diff += c1;
c1 = (*pX++) * (*pY++); diff += c1;
}
#endif
while (pX < pEnd1) diff += (*pX++) * (*pY++);
return 1 - diff;
}
template<typename T>
static inline float ComputeDistance(const T *p1, const T *p2, int length, SPTAG::DistCalcMethod distCalcMethod)
{
if (distCalcMethod == SPTAG::DistCalcMethod::L2)
return ComputeL2Distance(p1, p2, length);
return ComputeCosineDistance(p1, p2, length);
}
static inline float ConvertCosineSimilarityToDistance(float cs)
{
// Cosine similarity is in [-1, 1], the higher the value, the closer are the two vectors.
// However, the tree is built and searched based on "distance" between two vectors, that's >=0. The smaller the value, the closer are the two vectors.
// So we do a linear conversion from a cosine similarity to a distance value.
return 1 - cs; //[1, 3]
}
static inline float ConvertDistanceBackToCosineSimilarity(float d)
{
return 1 - d;
}
};
template<typename T>
float (*DistanceCalcSelector(SPTAG::DistCalcMethod p_method)) (const T*, const T*, int)
{
switch (p_method)
{
case SPTAG::DistCalcMethod::Cosine:
return &(DistanceUtils::ComputeCosineDistance);
case SPTAG::DistCalcMethod::L2:
return &(DistanceUtils::ComputeL2Distance);
default:
break;
}
return nullptr;
}
}
}
#endif // _SPTAG_COMMON_DISTANCEUTILS_H_

View File

@ -1,51 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_
#include <vector>
#include <mutex>
#include <memory>
namespace SPTAG
{
namespace COMMON
{
class FineGrainedLock {
public:
FineGrainedLock() {}
~FineGrainedLock() {
for (int i = 0; i < locks.size(); i++)
locks[i].reset();
locks.clear();
}
void resize(int n) {
int current = (int)locks.size();
if (current <= n) {
locks.resize(n);
for (int i = current; i < n; i++)
locks[i].reset(new std::mutex);
}
else {
for (int i = n; i < current; i++)
locks[i].reset();
locks.resize(n);
}
}
std::mutex& operator[](int idx) {
return *locks[idx];
}
const std::mutex& operator[](int idx) const {
return *locks[idx];
}
private:
std::vector<std::shared_ptr<std::mutex>> locks;
};
}
}
#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_

View File

@ -1,105 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_HEAP_H_
#define _SPTAG_COMMON_HEAP_H_
namespace SPTAG
{
namespace COMMON
{
// priority queue
template <typename T>
class Heap {
public:
Heap() : heap(nullptr), length(0), count(0) {}
Heap(int size) { Resize(size); }
void Resize(int size)
{
length = size;
heap.reset(new T[length + 1]); // heap uses 1-based indexing
count = 0;
lastlevel = int(pow(2.0, floor(log2(size))));
}
~Heap() {}
inline int size() { return count; }
inline bool empty() { return count == 0; }
inline void clear() { count = 0; }
inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; }
// Insert a new element in the heap.
void insert(T value)
{
/* If heap is full, then return without adding this element. */
int loc;
if (count == length) {
int maxi = lastlevel;
for (int i = lastlevel + 1; i <= length; i++)
if (heap[maxi] < heap[i]) maxi = i;
if (value > heap[maxi]) return;
loc = maxi;
}
else {
loc = ++(count); /* Remember 1-based indexing. */
}
/* Keep moving parents down until a place is found for this node. */
int par = (loc >> 1); /* Location of parent. */
while (par > 0 && value < heap[par]) {
heap[loc] = heap[par]; /* Move parent down to loc. */
loc = par;
par >>= 1;
}
/* Insert the element at the determined location. */
heap[loc] = value;
}
// Returns the node of minimum value from the heap (top of the heap).
bool pop(T& value)
{
if (count == 0) return false;
/* Switch first node with last. */
value = heap[1];
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return true; /* Return old last node. */
}
T& pop()
{
if (count == 0) return heap[0];
/* Switch first node with last. */
std::swap(heap[1], heap[count]);
count--;
heapify(); /* Move new node 1 to right position. */
return heap[count + 1]; /* Return old last node. */
}
private:
// Storage array for the heap.
// Type T must be comparable.
std::unique_ptr<T[]> heap;
int length;
int count; // Number of element in the heap
int lastlevel;
// Reorganizes the heap (a parent is smaller than its children) starting with a node.
void heapify()
{
int parent = 1, next = 2;
while (next < count) {
if (heap[next] > heap[next + 1]) next++;
if (heap[next] < heap[parent]) {
std::swap(heap[parent], heap[next]);
parent = next;
next <<= 1;
}
else break;
}
if (next == count && heap[next] < heap[parent]) std::swap(heap[parent], heap[next]);
}
};
}
}
#endif // _SPTAG_COMMON_HEAP_H_

View File

@ -1,358 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_KDTREE_H_
#define _SPTAG_COMMON_KDTREE_H_
#include <iostream>
#include <vector>
#include <string>
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "QueryResultSet.h"
#include "WorkSpace.h"
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
namespace SPTAG
{
namespace COMMON
{
// node type for storing KDT
struct KDTNode
{
int left;
int right;
short split_dim;
float split_value;
};
class KDTree
{
public:
KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000) {}
KDTree(KDTree& other) : m_iTreeNumber(other.m_iTreeNumber),
m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit),
m_iSamples(other.m_iSamples) {}
~KDTree() {}
inline const KDTNode& operator[](int index) const { return m_pTreeRoots[index]; }
inline KDTNode& operator[](int index) { return m_pTreeRoots[index]; }
inline int size() const { return (int)m_pTreeRoots.size(); }
template <typename T>
void BuildTrees(VectorIndex* p_index, std::vector<int>* indices = nullptr)
{
std::vector<int> localindices;
if (indices == nullptr) {
localindices.resize(p_index->GetNumSamples());
for (int i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i;
}
else {
localindices.assign(indices->begin(), indices->end());
}
m_pTreeRoots.resize(m_iTreeNumber * localindices.size());
m_pTreeStart.resize(m_iTreeNumber, 0);
#pragma omp parallel for
for (int i = 0; i < m_iTreeNumber; i++)
{
Sleep(i * 100); std::srand(clock());
std::vector<int> pindices(localindices.begin(), localindices.end());
std::random_shuffle(pindices.begin(), pindices.end());
m_pTreeStart[i] = i * (int)pindices.size();
std::cout << "Start to build KDTree " << i + 1 << std::endl;
int iTreeSize = m_pTreeStart[i];
DivideTree<T>(p_index, pindices, 0, (int)pindices.size() - 1, m_pTreeStart[i], iTreeSize);
std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl;
}
}
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
{
int treeNodeSize = (int)m_pTreeRoots.size();
size_t size = sizeof(int) +
sizeof(int) * m_iTreeNumber +
sizeof(int) +
sizeof(KDTNode) * treeNodeSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iTreeNumber;
ptr += sizeof(int);
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
ptr += sizeof(int) * m_iTreeNumber;
*(int*)ptr = treeNodeSize;
ptr += sizeof(int);
memcpy(ptr, m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
*pKDTMemFile = mem;
len = size;
return true;
}
bool SaveTrees(std::string sTreeFileName) const
{
std::cout << "Save KDT to " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize = (int)m_pTreeRoots.size();
fwrite(&treeNodeSize, sizeof(int), 1, fp);
fwrite(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
bool LoadTrees(char* pKDTMemFile)
{
m_iTreeNumber = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeStart.resize(m_iTreeNumber);
memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(int) * m_iTreeNumber);
pKDTMemFile += sizeof(int)*m_iTreeNumber;
int treeNodeSize = *((int*)pKDTMemFile);
pKDTMemFile += sizeof(int);
m_pTreeRoots.resize(treeNodeSize);
memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize);
return true;
}
bool LoadTrees(std::string sTreeFileName)
{
std::cout << "Load KDT From " << sTreeFileName << std::endl;
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iTreeNumber, sizeof(int), 1, fp);
m_pTreeStart.resize(m_iTreeNumber);
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
int treeNodeSize;
fread(&treeNodeSize, sizeof(int), 1, fp);
m_pTreeRoots.resize(treeNodeSize);
fread(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
fclose(fp);
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
return true;
}
template <typename T>
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
for (char i = 0; i < m_iTreeNumber; i++) {
KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0);
}
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
if (p_query.worstDist() < tcell.distance) break;
KDTSearch(p_index, p_query, p_space, tcell.node, true, tcell.distance);
}
}
template <typename T>
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
{
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
{
auto& tcell = p_space.m_SPTQueue.pop();
KDTSearch(p_index, p_query, p_space, tcell.node, false, tcell.distance);
}
}
private:
template <typename T>
void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
COMMON::WorkSpace& p_space, const int node, const bool isInit, const float distBound) const {
if (node < 0)
{
int index = -node - 1;
if (index >= p_index->GetNumSamples()) return;
#ifdef PREFETCH
const char* data = (const char *)(p_index->GetSample(index));
_mm_prefetch(data, _MM_HINT_T0);
_mm_prefetch(data + 64, _MM_HINT_T0);
#endif
if (p_space.CheckAndSet(index)) return;
++p_space.m_iNumberOfTreeCheckedLeaves;
++p_space.m_iNumberOfCheckedLeaves;
p_space.m_NGQueue.insert(COMMON::HeapCell(index, p_index->ComputeDistance((const void*)p_query.GetTarget(), (const void*)data)));
return;
}
auto& tnode = m_pTreeRoots[node];
float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
float distanceBound = distBound + diff * diff;
int otherChild, bestChild;
if (diff < 0)
{
bestChild = tnode.left;
otherChild = tnode.right;
}
else
{
otherChild = tnode.left;
bestChild = tnode.right;
}
if (!isInit || distanceBound < p_query.worstDist())
{
p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
}
KDTSearch(p_index, p_query, p_space, bestChild, isInit, distBound);
}
template <typename T>
void DivideTree(VectorIndex* p_index, std::vector<int>& indices, int first, int last,
int index, int &iTreeSize) {
ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last);
int i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last);
if (i - 1 <= first)
{
m_pTreeRoots[index].left = -indices[first] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].left = iTreeSize;
DivideTree<T>(p_index, indices, first, i - 1, iTreeSize, iTreeSize);
}
if (last == i)
{
m_pTreeRoots[index].right = -indices[last] - 1;
}
else
{
iTreeSize++;
m_pTreeRoots[index].right = iTreeSize;
DivideTree<T>(p_index, indices, i, last, iTreeSize, iTreeSize);
}
}
template <typename T>
void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<int>& indices, const int first, const int last)
{
std::vector<float> meanValues(p_index->GetFeatureDim(), 0);
std::vector<float> varianceValues(p_index->GetFeatureDim(), 0);
int end = min(first + m_iSamples, last);
int count = end - first + 1;
// calculate the mean of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] += v[k];
}
}
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
meanValues[k] /= count;
}
// calculate the variance of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)p_index->GetSample(indices[j]);
for (int k = 0; k < p_index->GetFeatureDim(); k++)
{
float dist = v[k] - meanValues[k];
varianceValues[k] += dist*dist;
}
}
// choose the split dimension as one of the dimension inside TOP_DIM maximum variance
node.split_dim = SelectDivisionDimension(varianceValues);
// determine the threshold
node.split_value = meanValues[node.split_dim];
}
int SelectDivisionDimension(const std::vector<float>& varianceValues) const
{
// Record the top maximum variances
std::vector<int> topind(m_numTopDimensionKDTSplit);
int num = 0;
// order the variances
for (int i = 0; i < varianceValues.size(); i++)
{
if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]])
{
if (num < m_numTopDimensionKDTSplit)
{
topind[num++] = i;
}
else
{
topind[num - 1] = i;
}
int j = num - 1;
// order the TOP_DIM variances
while (j > 0 && varianceValues[topind[j]] > varianceValues[topind[j - 1]])
{
std::swap(topind[j], topind[j - 1]);
j--;
}
}
}
// randomly choose a dimension from TOP_DIM
return topind[COMMON::Utils::rand_int(num)];
}
template <typename T>
int Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<int>& indices, const int first, const int last) const
{
int i = first;
int j = last;
// decide which child one point belongs
while (i <= j)
{
int ind = indices[i];
const T* v = (const T*)p_index->GetSample(ind);
float val = v[node.split_dim];
if (val < node.split_value)
{
i++;
}
else
{
std::swap(indices[i], indices[j]);
j--;
}
}
// if all the points in the node are equal,equally split the node into 2
if ((i == first) || (i == last + 1))
{
i = (first + last + 1) / 2;
}
return i;
}
private:
std::vector<int> m_pTreeStart;
std::vector<KDTNode> m_pTreeRoots;
public:
int m_iTreeNumber, m_numTopDimensionKDTSplit, m_iSamples;
};
}
}
#endif

View File

@ -1,436 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_NG_H_
#define _SPTAG_COMMON_NG_H_
#include "../VectorIndex.h"
#include "CommonUtils.h"
#include "Dataset.h"
#include "FineGrainedLock.h"
#include "QueryResultSet.h"
namespace SPTAG
{
namespace COMMON
{
class NeighborhoodGraph
{
public:
NeighborhoodGraph(): m_iTPTNumber(32),
m_iTPTLeafSize(2000),
m_iSamples(1000),
m_numTopDimensionTPTSplit(5),
m_iNeighborhoodSize(32),
m_iNeighborhoodScale(2),
m_iCEFScale(2),
m_iRefineIter(0),
m_iCEF(1000),
m_iMaxCheckForRefineGraph(10000) {}
~NeighborhoodGraph() {}
virtual void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist) = 0;
virtual void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) = 0;
virtual float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr) = 0;
template <typename T>
void BuildGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
{
std::cout << "build RNG graph!" << std::endl;
m_iGraphSize = index->GetNumSamples();
m_iNeighborhoodSize = m_iNeighborhoodSize * m_iNeighborhoodScale;
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
m_dataUpdateLock.resize(m_iGraphSize);
if (m_iGraphSize < 1000) {
RefineGraph<T>(index, idmap);
std::cout << "Build RNG Graph end!" << std::endl;
return;
}
{
COMMON::Dataset<float> NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
std::vector<std::vector<int>> TptreeDataIndices(m_iTPTNumber, std::vector<int>(m_iGraphSize));
std::vector<std::vector<std::pair<int, int>>> TptreeLeafNodes(m_iTPTNumber, std::vector<std::pair<int, int>>());
for (int i = 0; i < m_iGraphSize; i++)
for (int j = 0; j < m_iNeighborhoodSize; j++)
(NeighborhoodDists)[i][j] = MaxDist;
std::cout << "Parallel TpTree Partition begin " << std::endl;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iTPTNumber; i++)
{
Sleep(i * 100); std::srand(clock());
for (int j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j;
std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end());
PartitionByTptree<T>(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]);
std::cout << "Finish Getting Leaves for Tree " << i << std::endl;
}
std::cout << "Parallel TpTree Partition done" << std::endl;
for (int i = 0; i < m_iTPTNumber; i++)
{
#pragma omp parallel for schedule(dynamic)
for (int j = 0; j < TptreeLeafNodes[i].size(); j++)
{
int start_index = TptreeLeafNodes[i][j].first;
int end_index = TptreeLeafNodes[i][j].second;
if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%';
for (int x = start_index; x < end_index; x++)
{
for (int y = x + 1; y <= end_index; y++)
{
int p1 = TptreeDataIndices[i][x];
int p2 = TptreeDataIndices[i][y];
float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2));
if (idmap != nullptr) {
p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1);
p2 = (idmap->find(p2) == idmap->end()) ? p2 : idmap->at(p2);
}
COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], (NeighborhoodDists)[p1], m_iNeighborhoodSize);
COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], (NeighborhoodDists)[p2], m_iNeighborhoodSize);
}
}
}
TptreeDataIndices[i].clear();
TptreeLeafNodes[i].clear();
std::cout << std::endl;
}
TptreeDataIndices.clear();
TptreeLeafNodes.clear();
}
if (m_iMaxCheckForRefineGraph > 0) {
RefineGraph<T>(index, idmap);
}
}
template <typename T>
void RefineGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
{
m_iCEF *= m_iCEFScale;
m_iMaxCheckForRefineGraph *= m_iCEFScale;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++)
{
RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%";
}
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
m_iCEF /= m_iCEFScale;
m_iMaxCheckForRefineGraph /= m_iCEFScale;
m_iNeighborhoodSize /= m_iNeighborhoodScale;
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < m_iGraphSize; i++)
{
RefineNode<T>(index, i, false);
if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%";
}
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
if (idmap != nullptr) {
for (auto iter = idmap->begin(); iter != idmap->end(); iter++)
if (iter->first < 0)
{
m_pNeighborhoodGraph[-1 - iter->first][m_iNeighborhoodSize - 1] = -2 - iter->second;
}
}
}
template <typename T>
ErrorCode RefineGraph(VectorIndex* index, std::vector<int>& indices, std::vector<int>& reverseIndices,
std::string graphFileName, const std::unordered_map<int, int>* idmap = nullptr)
{
int R = (int)indices.size();
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < R; i++)
{
RefineNode<T>(index, indices[i], false);
int* nodes = m_pNeighborhoodGraph[indices[i]];
for (int j = 0; j < m_iNeighborhoodSize; j++)
{
if (nodes[j] < 0) nodes[j] = -1;
else nodes[j] = reverseIndices[nodes[j]];
}
if (idmap == nullptr || idmap->find(-1 - indices[i]) == idmap->end()) continue;
nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]);
}
std::ofstream graphOut(graphFileName, std::ios::binary);
if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
graphOut.write((char*)&R, sizeof(int));
graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
for (int i = 0; i < R; i++) {
graphOut.write((char*)m_pNeighborhoodGraph[indices[i]], sizeof(int) * m_iNeighborhoodSize);
}
graphOut.close();
return ErrorCode::Success;
}
template <typename T>
void RefineNode(VectorIndex* index, const int node, bool updateNeighbors)
{
COMMON::QueryResultSet<T> query((const T*)index->GetSample(node), m_iCEF + 1);
index->SearchIndex(query);
RebuildNeighbors(index, node, m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF + 1);
if (updateNeighbors) {
// update neighbors
for (int j = 0; j <= m_iCEF; j++)
{
BasicResult* item = query.GetResult(j);
if (item->VID < 0) break;
if (item->VID == node) continue;
std::lock_guard<std::mutex> lock(m_dataUpdateLock[item->VID]);
InsertNeighbors(index, item->VID, node, item->Dist);
}
}
}
template <typename T>
void PartitionByTptree(VectorIndex* index, std::vector<int>& indices, const int first, const int last,
std::vector<std::pair<int, int>> & leaves)
{
if (last - first <= m_iTPTLeafSize)
{
leaves.push_back(std::make_pair(first, last));
}
else
{
std::vector<float> Mean(index->GetFeatureDim(), 0);
int iIteration = 100;
int end = min(first + m_iSamples, last);
int count = end - first + 1;
// calculate the mean of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++)
{
Mean[k] += v[k];
}
}
for (int k = 0; k < index->GetFeatureDim(); k++)
{
Mean[k] /= count;
}
std::vector<BasicResult> Variance;
Variance.reserve(index->GetFeatureDim());
for (int j = 0; j < index->GetFeatureDim(); j++)
{
Variance.push_back(BasicResult(j, 0));
}
// calculate the variance of each dimension
for (int j = first; j <= end; j++)
{
const T* v = (const T*)index->GetSample(indices[j]);
for (int k = 0; k < index->GetFeatureDim(); k++)
{
float dist = v[k] - Mean[k];
Variance[k].Dist += dist*dist;
}
}
std::sort(Variance.begin(), Variance.end(), COMMON::Compare);
std::vector<int> indexs(m_numTopDimensionTPTSplit);
std::vector<float> weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit);
float bestvariance = Variance[index->GetFeatureDim() - 1].Dist;
for (int i = 0; i < m_numTopDimensionTPTSplit; i++)
{
indexs[i] = Variance[index->GetFeatureDim() - 1 - i].VID;
bestweight[i] = 0;
}
bestweight[0] = 1;
float bestmean = Mean[indexs[0]];
std::vector<float> Val(count);
for (int i = 0; i < iIteration; i++)
{
float sumweight = 0;
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
weight[j] = float(rand() % 10000) / 5000.0f - 1.0f;
sumweight += weight[j] * weight[j];
}
sumweight = sqrt(sumweight);
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
weight[j] /= sumweight;
}
float mean = 0;
for (int j = 0; j < count; j++)
{
Val[j] = 0;
const T* v = (const T*)index->GetSample(indices[first + j]);
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
{
Val[j] += weight[k] * v[indexs[k]];
}
mean += Val[j];
}
mean /= count;
float var = 0;
for (int j = 0; j < count; j++)
{
float dist = Val[j] - mean;
var += dist * dist;
}
if (var > bestvariance)
{
bestvariance = var;
bestmean = mean;
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
{
bestweight[j] = weight[j];
}
}
}
int i = first;
int j = last;
// decide which child one point belongs
while (i <= j)
{
float val = 0;
const T* v = (const T*)index->GetSample(indices[i]);
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
{
val += bestweight[k] * v[indexs[k]];
}
if (val < bestmean)
{
i++;
}
else
{
std::swap(indices[i], indices[j]);
j--;
}
}
// if all the points in the node are equal,equally split the node into 2
if ((i == first) || (i == last + 1))
{
i = (first + last + 1) / 2;
}
Mean.clear();
Variance.clear();
Val.clear();
indexs.clear();
weight.clear();
bestweight.clear();
PartitionByTptree<T>(index, indices, first, i - 1, leaves);
PartitionByTptree<T>(index, indices, i, last, leaves);
}
}
bool LoadGraph(std::string sGraphFilename)
{
std::cout << "Load Graph From " << sGraphFilename << std::endl;
FILE * fp = fopen(sGraphFilename.c_str(), "rb");
if (fp == NULL) return false;
fread(&m_iGraphSize, sizeof(int), 1, fp);
fread(&m_iNeighborhoodSize, sizeof(int), 1, fp);
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
m_dataUpdateLock.resize(m_iGraphSize);
for (int i = 0; i < m_iGraphSize; i++)
{
fread((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Load Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true;
}
bool LoadGraphFromMemory(char* pGraphMemFile)
{
m_iGraphSize = *((int*)pGraphMemFile);
pGraphMemFile += sizeof(int);
m_iNeighborhoodSize = *((int*)pGraphMemFile);
pGraphMemFile += sizeof(int);
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize, (int*)pGraphMemFile);
m_dataUpdateLock.resize(m_iGraphSize);
return true;
}
bool SaveGraph(std::string sGraphFilename) const
{
std::cout << "Save Graph To " << sGraphFilename << std::endl;
FILE *fp = fopen(sGraphFilename.c_str(), "wb");
if (fp == NULL) return false;
fwrite(&m_iGraphSize, sizeof(int), 1, fp);
fwrite(&m_iNeighborhoodSize, sizeof(int), 1, fp);
for (int i = 0; i < m_iGraphSize; i++)
{
fwrite((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
}
fclose(fp);
std::cout << "Save Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
return true;
}
bool SaveGraphToMemory(void **pGraphMemFile, int64_t &len) {
size_t size = sizeof(int) + sizeof(int) + sizeof(int) * m_iNeighborhoodSize * m_iGraphSize;
char *mem = (char*)malloc(size);
if (mem == NULL) return false;
auto ptr = mem;
*(int*)ptr = m_iGraphSize;
ptr += sizeof(int);
*(int*)ptr = m_iNeighborhoodSize;
ptr += sizeof(int);
for (int i = 0; i < m_iGraphSize; i++)
{
memcpy(ptr, (m_pNeighborhoodGraph)[i], sizeof(int) * m_iNeighborhoodSize);
ptr += sizeof(int) * m_iNeighborhoodSize;
}
*pGraphMemFile = mem;
len = size;
return true;
}
inline void AddBatch(int num) { m_pNeighborhoodGraph.AddBatch(num); m_iGraphSize += num; m_dataUpdateLock.resize(m_iGraphSize); }
inline int* operator[](int index) { return m_pNeighborhoodGraph[index]; }
inline const int* operator[](int index) const { return m_pNeighborhoodGraph[index]; }
inline void SetR(int rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); }
inline int R() const { return m_iGraphSize; }
static std::shared_ptr<NeighborhoodGraph> CreateInstance(std::string type);
protected:
// Graph structure
int m_iGraphSize;
COMMON::Dataset<int> m_pNeighborhoodGraph;
COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph
public:
int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit;
int m_iNeighborhoodSize, m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph;
};
}
}
#endif

View File

@ -1,96 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_QUERYRESULTSET_H_
#define _SPTAG_COMMON_QUERYRESULTSET_H_
#include "../SearchQuery.h"
namespace SPTAG
{
namespace COMMON
{
inline bool operator < (const BasicResult& lhs, const BasicResult& rhs)
{
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
inline bool Compare(const BasicResult& lhs, const BasicResult& rhs)
{
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
}
// Space to save temporary answer, similar with TopKCache
template<typename T>
class QueryResultSet : public QueryResult
{
public:
QueryResultSet(const T* _target, int _K) : QueryResult(_target, _K, false)
{
}
QueryResultSet(const QueryResultSet& other) : QueryResult(other)
{
}
inline void SetTarget(const T *p_target)
{
m_target = p_target;
}
inline const T* GetTarget() const
{
return reinterpret_cast<const T*>(m_target);
}
inline float worstDist() const
{
return m_results[0].Dist;
}
bool AddPoint(const int index, float dist)
{
if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID))
{
m_results[0].VID = index;
m_results[0].Dist = dist;
Heapify(m_resultNum);
return true;
}
return false;
}
inline void SortResult()
{
for (int i = m_resultNum - 1; i >= 0; i--)
{
std::swap(m_results[0], m_results[i]);
Heapify(i);
}
}
private:
void Heapify(int count)
{
int parent = 0, next = 1, maxidx = count - 1;
while (next < maxidx)
{
if (m_results[next] < m_results[next + 1]) next++;
if (m_results[parent] < m_results[next])
{
std::swap(m_results[next], m_results[parent]);
parent = next;
next = (parent << 1) + 1;
}
else break;
}
if (next == maxidx && m_results[parent] < m_results[next]) std::swap(m_results[parent], m_results[next]);
}
};
}
}
#endif // _SPTAG_COMMON_QUERYRESULTSET_H_

View File

@ -1,123 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_RNG_H_
#define _SPTAG_COMMON_RNG_H_
#include "NeighborhoodGraph.h"
namespace SPTAG
{
namespace COMMON
{
class RelativeNeighborhoodGraph: public NeighborhoodGraph
{
public:
void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) {
int count = 0;
for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) {
const BasicResult& item = queryResults[j];
if (item.VID < 0) break;
if (item.VID == node) continue;
bool good = true;
for (int k = 0; k < count; k++) {
if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) {
good = false;
break;
}
}
if (good) nodes[count++] = item.VID;
}
for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1;
}
void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist)
{
int* nodes = m_pNeighborhoodGraph[node];
for (int k = 0; k < m_iNeighborhoodSize; k++)
{
int tmpNode = nodes[k];
if (tmpNode < -1) continue;
if (tmpNode < 0)
{
bool good = true;
for (int t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false;
break;
}
}
if (good) {
nodes[k] = insertNode;
}
break;
}
float tmpDist = index->ComputeDistance(index->GetSample(node), index->GetSample(tmpNode));
if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode))
{
bool good = true;
for (int t = 0; t < k; t++) {
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
good = false;
break;
}
}
if (good) {
nodes[k] = insertNode;
insertNode = tmpNode;
insertDist = tmpDist;
}
else {
break;
}
}
}
}
float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr)
{
int* correct = new int[samples];
#pragma omp parallel for schedule(dynamic)
for (int i = 0; i < samples; i++)
{
int x = COMMON::Utils::rand_int(m_iGraphSize);
//int x = i;
COMMON::QueryResultSet<void> query(nullptr, m_iCEF);
for (int y = 0; y < m_iGraphSize; y++)
{
if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue;
float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y));
query.AddPoint(y, dist);
}
query.SortResult();
int * exact_rng = new int[m_iNeighborhoodSize];
RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF);
correct[i] = 0;
for (int j = 0; j < m_iNeighborhoodSize; j++) {
if (exact_rng[j] == -1) {
correct[i] += m_iNeighborhoodSize - j;
break;
}
for (int k = 0; k < m_iNeighborhoodSize; k++)
if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) {
correct[i]++;
break;
}
}
delete[] exact_rng;
}
float acc = 0;
for (int i = 0; i < samples; i++) acc += float(correct[i]);
acc = acc / samples / m_iNeighborhoodSize;
delete[] correct;
return acc;
}
};
}
}
#endif

View File

@ -1,185 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_WORKSPACE_H_
#define _SPTAG_COMMON_WORKSPACE_H_
#include "CommonUtils.h"
#include "Heap.h"
namespace SPTAG
{
namespace COMMON
{
// node type in the priority queue
struct HeapCell
{
int node;
float distance;
HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {}
inline bool operator < (const HeapCell& rhs)
{
return distance < rhs.distance;
}
inline bool operator > (const HeapCell& rhs)
{
return distance > rhs.distance;
}
};
class OptHashPosVector
{
protected:
// Max loop number in one hash block.
static const int m_maxLoop = 8;
// Max pool size.
static const int m_poolSize = 8191;
// Could we use the second hash block.
bool m_secondHash;
// Record 2 hash tables.
// [0~m_poolSize + 1) is the first block.
// [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block;
int m_hashTable[(m_poolSize + 1) * 2];
inline unsigned hash_func2(int idx, int loop)
{
return ((unsigned)idx + loop) & m_poolSize;
}
inline unsigned hash_func(unsigned idx)
{
return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & m_poolSize;
}
public:
OptHashPosVector() {}
~OptHashPosVector() {}
void Init(int size)
{
m_secondHash = true;
clear();
}
void clear()
{
if (!m_secondHash)
{
// Clear first block.
memset(&m_hashTable[0], 0, sizeof(int)*(m_poolSize + 1));
}
else
{
// Clear all blocks.
memset(&m_hashTable[0], 0, 2 * sizeof(int) * (m_poolSize + 1));
m_secondHash = false;
}
}
inline bool CheckAndSet(int idx)
{
// Inner Index is begin from 1
return _CheckAndSet(&m_hashTable[0], idx + 1) == 0;
}
inline int _CheckAndSet(int* hashTable, int idx)
{
unsigned index, loop;
// Get first hash position.
index = hash_func(idx);
for (loop = 0; loop < m_maxLoop; ++loop)
{
if (!hashTable[index])
{
// index first match and record it.
hashTable[index] = idx;
return 1;
}
if (hashTable[index] == idx)
{
// Hit this item in hash table.
return 0;
}
// Get next hash position.
index = hash_func2(index, loop);
}
if (hashTable == &m_hashTable[0])
{
// Use second hash block.
m_secondHash = true;
return _CheckAndSet(&m_hashTable[m_poolSize + 1], idx);
}
// Do not include this item.
return -1;
}
};
// Variables for each single NN search
struct WorkSpace
{
void Initialize(int maxCheck, int dataSize)
{
nodeCheckStatus.Init(dataSize);
m_SPTQueue.Resize(maxCheck * 10);
m_NGQueue.Resize(maxCheck * 30);
m_iNumberOfTreeCheckedLeaves = 0;
m_iNumberOfCheckedLeaves = 0;
m_iContinuousLimit = maxCheck / 64;
m_iMaxCheck = maxCheck;
m_iNumOfContinuousNoBetterPropagation = 0;
}
void Reset(int maxCheck)
{
nodeCheckStatus.clear();
m_SPTQueue.clear();
m_NGQueue.clear();
m_iNumberOfTreeCheckedLeaves = 0;
m_iNumberOfCheckedLeaves = 0;
m_iContinuousLimit = maxCheck / 64;
m_iMaxCheck = maxCheck;
m_iNumOfContinuousNoBetterPropagation = 0;
}
inline bool CheckAndSet(int idx)
{
return nodeCheckStatus.CheckAndSet(idx);
}
OptHashPosVector nodeCheckStatus;
//OptHashPosVector nodeCheckStatus;
// counter for dynamic pivoting
int m_iNumOfContinuousNoBetterPropagation;
int m_iContinuousLimit;
int m_iNumberOfTreeCheckedLeaves;
int m_iNumberOfCheckedLeaves;
int m_iMaxCheck;
// Prioriy queue used for neighborhood graph
Heap<HeapCell> m_NGQueue;
// Priority queue Used for BKT-Tree
Heap<HeapCell> m_SPTQueue;
};
}
}
#endif // _SPTAG_COMMON_WORKSPACE_H_

View File

@ -1,43 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMON_WORKSPACEPOOL_H_
#define _SPTAG_COMMON_WORKSPACEPOOL_H_
#include "WorkSpace.h"
#include <list>
#include <mutex>
namespace SPTAG
{
namespace COMMON
{
class WorkSpacePool
{
public:
WorkSpacePool(int p_maxCheck, int p_vectorCount);
virtual ~WorkSpacePool();
std::shared_ptr<WorkSpace> Rent();
void Return(const std::shared_ptr<WorkSpace>& p_workSpace);
void Init(int size);
private:
std::list<std::shared_ptr<WorkSpace>> m_workSpacePool;
std::mutex m_workSpacePoolMutex;
int m_maxCheck;
int m_vectorCount;
};
}
}
#endif // _SPTAG_COMMON_WORKSPACEPOOL_H_

View File

@ -1,56 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_COMMONDATASTRUCTURE_H_
#define _SPTAG_COMMONDATASTRUCTURE_H_
#include "Common.h"
namespace SPTAG
{
class ByteArray
{
public:
ByteArray();
ByteArray(ByteArray&& p_right);
ByteArray(std::uint8_t* p_array, std::size_t p_length, bool p_transferOnwership);
ByteArray(std::uint8_t* p_array, std::size_t p_length, std::shared_ptr<std::uint8_t> p_dataHolder);
ByteArray(const ByteArray& p_right);
ByteArray& operator= (const ByteArray& p_right);
ByteArray& operator= (ByteArray&& p_right);
~ByteArray();
static ByteArray Alloc(std::size_t p_length);
std::uint8_t* Data() const;
std::size_t Length() const;
void SetData(std::uint8_t* p_array, std::size_t p_length);
std::shared_ptr<std::uint8_t> DataHolder() const;
void Clear();
const static ByteArray c_empty;
private:
std::uint8_t* m_data;
std::size_t m_length;
// Notice this is holding an array. Set correct deleter for this.
std::shared_ptr<std::uint8_t> m_dataHolder;
};
} // namespace SPTAG
#endif // _SPTAG_COMMONDATASTRUCTURE_H_

View File

@ -1,57 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineVectorValueType
DefineVectorValueType(Int8, std::int8_t)
DefineVectorValueType(UInt8, std::uint8_t)
DefineVectorValueType(Int16, std::int16_t)
DefineVectorValueType(Float, float)
#endif // DefineVectorValueType
#ifdef DefineDistCalcMethod
DefineDistCalcMethod(L2)
DefineDistCalcMethod(Cosine)
#endif // DefineDistCalcMethod
#ifdef DefineErrorCode
// 0x0000 ~ 0x0FFF General Status
DefineErrorCode(Success, 0x0000)
DefineErrorCode(Fail, 0x0001)
DefineErrorCode(FailedOpenFile, 0x0002)
DefineErrorCode(FailedCreateFile, 0x0003)
DefineErrorCode(ParamNotFound, 0x0010)
DefineErrorCode(FailedParseValue, 0x0011)
// 0x1000 ~ 0x1FFF Index Build Status
// 0x2000 ~ 0x2FFF Index Serve Status
// 0x3000 ~ 0x3FFF Helper Function Status
DefineErrorCode(ReadIni_FailedParseSection, 0x3000)
DefineErrorCode(ReadIni_FailedParseParam, 0x3001)
DefineErrorCode(ReadIni_DuplicatedSection, 0x3002)
DefineErrorCode(ReadIni_DuplicatedParam, 0x3003)
// 0x4000 ~ 0x4FFF Socket Library Status
DefineErrorCode(Socket_FailedResolveEndPoint, 0x4000)
DefineErrorCode(Socket_FailedConnectToEndPoint, 0x4001)
#endif // DefineErrorCode
#ifdef DefineIndexAlgo
DefineIndexAlgo(BKT)
DefineIndexAlgo(KDT)
#endif // DefineIndexAlgo

View File

@ -1,112 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_KDT_INDEX_H_
#define _SPTAG_KDT_INDEX_H_
#include "../Common.h"
#include "../VectorIndex.h"
#include "../Common/CommonUtils.h"
#include "../Common/DistanceUtils.h"
#include "../Common/QueryResultSet.h"
#include "../Common/Dataset.h"
#include "../Common/WorkSpace.h"
#include "../Common/WorkSpacePool.h"
#include "../Common/RelativeNeighborhoodGraph.h"
#include "../Common/KDTree.h"
#include "inc/Helper/StringConvert.h"
#include "inc/Helper/SimpleIniReader.h"
#include <functional>
#include <mutex>
#include <tbb/concurrent_unordered_set.h>
namespace SPTAG
{
namespace Helper
{
class IniReader;
}
namespace KDT
{
template<typename T>
class Index : public VectorIndex
{
private:
// data points
COMMON::Dataset<T> m_pSamples;
// KDT structures.
COMMON::KDTree m_pTrees;
// Graph structure
COMMON::RelativeNeighborhoodGraph m_pGraph;
std::string m_sKDTFilename;
std::string m_sGraphFilename;
std::string m_sDataPointsFilename;
std::mutex m_dataLock; // protect data and graph
tbb::concurrent_unordered_set<int> m_deletedID;
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
int m_iNumberOfThreads;
DistCalcMethod m_iDistCalcMethod;
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
int m_iMaxCheck;
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
int m_iNumberOfInitialDynamicPivots;
int m_iNumberOfOtherDynamicPivots;
public:
Index()
{
#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
VarName = DefaultValue; \
#include "inc/Core/KDT/ParameterDefinitionList.h"
#undef DefineKDTParameter
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
}
~Index() {}
inline int GetNumSamples() const { return m_pSamples.R(); }
inline int GetFeatureDim() const { return m_pSamples.C(); }
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
inline int GetNumThreads() const { return m_iNumberOfThreads; }
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; }
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
ErrorCode SearchIndex(QueryResult &p_query) const;
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
ErrorCode SetParameter(const char* p_param, const char* p_value);
std::string GetParameter(const char* p_param) const;
private:
ErrorCode RefineIndex(const std::string& p_folderPath);
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
};
} // namespace KDT
} // namespace SPTAG
#endif // _SPTAG_KDT_INDEX_H_

View File

@ -1,34 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifdef DefineKDTParameter
// DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr)
DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber")
DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit")
DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "NumSamplesKDTSplitConsideration")
DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit")
DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
DefineKDTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
DefineKDTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
#endif

View File

@ -1,114 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_METADATASET_H_
#define _SPTAG_METADATASET_H_
#include "CommonDataStructure.h"
#include <iostream>
#include <fstream>
namespace SPTAG
{
class MetadataSet
{
public:
MetadataSet();
virtual ~MetadataSet();
virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0;
virtual SizeType Count() const = 0;
virtual bool Available() const = 0;
virtual void AddBatch(MetadataSet& data) = 0;
virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0;
virtual ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) = 0;
virtual ErrorCode LoadMetadataFromMemory(void *pGraphMemFile) = 0;
virtual ErrorCode RefineMetadata(std::vector<int>& indices, const std::string& p_folderPath);
static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst);
};
class FileMetadataSet : public MetadataSet
{
public:
FileMetadataSet(const std::string& p_metaFile, const std::string& p_metaindexFile);
~FileMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const;
SizeType Count() const;
bool Available() const;
void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private:
std::ifstream* m_fp = nullptr;
std::vector<std::uint64_t> m_pOffsets;
SizeType m_count;
std::string m_metaFile;
std::string m_metaindexFile;
std::vector<std::uint8_t> m_newdata;
};
class MemMetadataSet : public MetadataSet
{
public:
MemMetadataSet() = default;
MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count);
~MemMetadataSet();
ByteArray GetMetadata(IndexType p_vectorID) const;
SizeType Count() const;
bool Available() const;
void AddBatch(MetadataSet& data);
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
private:
std::vector<std::uint64_t> m_offsets;
SizeType m_count;
ByteArray m_metadataHolder;
ByteArray m_offsetHolder;
std::vector<std::uint8_t> m_newdata;
};
} // namespace SPTAG
#endif // _SPTAG_METADATASET_H_

View File

@ -1,239 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SEARCHQUERY_H_
#define _SPTAG_SEARCHQUERY_H_
#include "CommonDataStructure.h"
#include <cstring>
namespace SPTAG
{
struct BasicResult
{
int VID;
float Dist;
BasicResult() : VID(-1), Dist(MaxDist) {}
BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {}
};
// Space to save temporary answer, similar with TopKCache
class QueryResult
{
public:
typedef BasicResult* iterator;
typedef const BasicResult* const_iterator;
QueryResult()
: m_target(nullptr),
m_resultNum(0),
m_withMeta(false)
{
}
QueryResult(const void* p_target, int p_resultNum, bool p_withMeta)
: m_target(nullptr),
m_resultNum(0),
m_withMeta(false)
{
Init(p_target, p_resultNum, p_withMeta);
}
QueryResult(const void* p_target, int p_resultNum, std::vector<BasicResult>& p_results)
: m_target(p_target),
m_resultNum(p_resultNum),
m_withMeta(false)
{
p_results.resize(p_resultNum);
m_results.reset(p_results.data());
}
QueryResult(const QueryResult& p_other)
: m_target(p_other.m_target),
m_resultNum(p_other.m_resultNum),
m_withMeta(p_other.m_withMeta)
{
if (m_resultNum > 0)
{
m_results.reset(new BasicResult[m_resultNum]);
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
if (m_withMeta)
{
m_metadatas.reset(new ByteArray[m_resultNum]);
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
}
}
QueryResult& operator=(const QueryResult& p_other)
{
Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta);
if (m_resultNum > 0)
{
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
if (m_withMeta)
{
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
}
}
return *this;
}
~QueryResult()
{
}
inline void Init(const void* p_target, int p_resultNum, bool p_withMeta)
{
m_target = p_target;
if (p_resultNum > m_resultNum)
{
m_results.reset(new BasicResult[p_resultNum]);
}
if (p_withMeta && (!m_withMeta || p_resultNum > m_resultNum))
{
m_metadatas.reset(new ByteArray[p_resultNum]);
}
m_resultNum = p_resultNum;
m_withMeta = p_withMeta;
}
inline int GetResultNum() const
{
return m_resultNum;
}
inline const void* GetTarget()
{
return m_target;
}
inline void SetTarget(const void* p_target)
{
m_target = p_target;
}
inline BasicResult* GetResult(int i) const
{
return i < m_resultNum ? m_results.get() + i : nullptr;
}
inline void SetResult(int p_index, int p_VID, float p_dist)
{
if (p_index < m_resultNum)
{
m_results[p_index].VID = p_VID;
m_results[p_index].Dist = p_dist;
}
}
inline BasicResult* GetResults() const
{
return m_results.get();
}
inline bool WithMeta() const
{
return m_withMeta;
}
inline const ByteArray& GetMetadata(int p_index) const
{
if (p_index < m_resultNum && m_withMeta)
{
return m_metadatas[p_index];
}
return ByteArray::c_empty;
}
inline void SetMetadata(int p_index, ByteArray p_metadata)
{
if (p_index < m_resultNum && m_withMeta)
{
m_metadatas[p_index] = std::move(p_metadata);
}
}
inline void Reset()
{
for (int i = 0; i < m_resultNum; i++)
{
m_results[i].VID = -1;
m_results[i].Dist = MaxDist;
}
if (m_withMeta)
{
for (int i = 0; i < m_resultNum; i++)
{
m_metadatas[i].Clear();
}
}
}
iterator begin()
{
return m_results.get();
}
iterator end()
{
return m_results.get() + m_resultNum;
}
const_iterator begin() const
{
return m_results.get();
}
const_iterator end() const
{
return m_results.get() + m_resultNum;
}
protected:
const void* m_target;
int m_resultNum;
bool m_withMeta;
std::unique_ptr<BasicResult[]> m_results;
std::unique_ptr<ByteArray[]> m_metadatas;
};
} // namespace SPTAG
#endif // _SPTAG_SEARCHQUERY_H_

View File

@ -1,94 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_VECTORINDEX_H_
#define _SPTAG_VECTORINDEX_H_
#include "Common.h"
#include "SearchQuery.h"
#include "VectorSet.h"
#include "MetadataSet.h"
#include "inc/Helper/SimpleIniReader.h"
namespace SPTAG
{
class VectorIndex
{
public:
VectorIndex();
virtual ~VectorIndex();
virtual ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout) = 0;
virtual ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader) = 0;
virtual ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen) = 0;
virtual ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs) = 0;
virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0;
virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0;
virtual ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum) = 0;
//virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0;
//virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0;
virtual float ComputeDistance(const void* pX, const void* pY) const = 0;
virtual const void* GetSample(const int idx) const = 0;
virtual int GetFeatureDim() const = 0;
virtual int GetNumSamples() const = 0;
virtual DistCalcMethod GetDistCalcMethod() const = 0;
virtual IndexAlgoType GetIndexAlgoType() const = 0;
virtual VectorValueType GetVectorValueType() const = 0;
virtual int GetNumThreads() const = 0;
virtual std::string GetParameter(const char* p_param) const = 0;
virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
virtual ErrorCode LoadIndex(const std::string& p_folderPath);
virtual ErrorCode SaveIndex(const std::string& p_folderPath);
virtual ErrorCode BuildIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector<BasicResult>& p_results) const;
virtual ErrorCode AddIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
virtual std::string GetParameter(const std::string& p_param) const;
virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
virtual ByteArray GetMetadata(IndexType p_vectorID) const;
virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath);
virtual std::string GetIndexName() const
{
if (m_sIndexName == "")
return Helper::Convert::ConvertToString(GetIndexAlgoType());
return m_sIndexName;
}
virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; }
static std::shared_ptr<VectorIndex> CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype);
static ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2);
static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr<VectorIndex>& p_vectorIndex);
protected:
std::string m_sIndexName;
std::shared_ptr<MetadataSet> m_pMetadata;
};
} // namespace SPTAG
#endif // _SPTAG_VECTORINDEX_H_

View File

@ -1,73 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_VECTORSET_H_
#define _SPTAG_VECTORSET_H_
#include "CommonDataStructure.h"
namespace SPTAG
{
class VectorSet
{
public:
VectorSet();
virtual ~VectorSet();
virtual VectorValueType GetValueType() const = 0;
virtual void* GetVector(IndexType p_vectorID) const = 0;
virtual void* GetData() const = 0;
virtual SizeType Dimension() const = 0;
virtual SizeType Count() const = 0;
virtual bool Available() const = 0;
virtual ErrorCode Save(const std::string& p_vectorFile) const = 0;
};
class BasicVectorSet : public VectorSet
{
public:
BasicVectorSet(const ByteArray& p_bytesArray,
VectorValueType p_valueType,
SizeType p_dimension,
SizeType p_vectorCount);
virtual ~BasicVectorSet();
virtual VectorValueType GetValueType() const;
virtual void* GetVector(IndexType p_vectorID) const;
virtual void* GetData() const;
virtual SizeType Dimension() const;
virtual SizeType Count() const;
virtual bool Available() const;
virtual ErrorCode Save(const std::string& p_vectorFile) const;
private:
ByteArray m_data;
VectorValueType m_valueType;
SizeType m_dimension;
SizeType m_vectorCount;
SizeType m_perVectorDataSize;
};
} // namespace SPTAG
#endif // _SPTAG_VECTORSET_H_

View File

@ -1,253 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_ARGUMENTSPARSER_H_
#define _SPTAG_HELPER_ARGUMENTSPARSER_H_
#include "inc/Helper/StringConvert.h"
#include <cstdint>
#include <cstddef>
#include <memory>
#include <vector>
#include <string>
namespace SPTAG
{
namespace Helper
{
class ArgumentsParser
{
public:
ArgumentsParser();
virtual ~ArgumentsParser();
virtual bool Parse(int p_argc, char** p_args);
virtual void PrintHelp();
protected:
class IArgument
{
public:
IArgument();
virtual ~IArgument();
virtual bool ParseValue(int& p_restArgc, char** (&p_args)) = 0;
virtual void PrintDescription(FILE* p_output) = 0;
virtual bool IsRequiredButNotSet() const = 0;
};
template<typename DataType>
class ArgumentT : public IArgument
{
public:
ArgumentT(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
bool p_followedValue,
const DataType& p_switchAsValue,
bool p_isRequired)
: m_value(p_target),
m_representStringShort(p_representStringShort),
m_representString(p_representString),
m_description(p_description),
m_followedValue(p_followedValue),
c_switchAsValue(p_switchAsValue),
m_isRequired(p_isRequired),
m_isSet(false)
{
}
virtual ~ArgumentT()
{
}
virtual bool ParseValue(int& p_restArgc, char** (&p_args))
{
if (0 == p_restArgc)
{
return true;
}
if (0 != strcmp(*p_args, m_representString.c_str())
&& 0 != strcmp(*p_args, m_representStringShort.c_str()))
{
return true;
}
if (!m_followedValue)
{
m_value = c_switchAsValue;
--p_restArgc;
++p_args;
m_isSet = true;
return true;
}
if (p_restArgc < 2)
{
return false;
}
DataType tmp;
if (!Helper::Convert::ConvertStringTo(p_args[1], tmp))
{
return false;
}
m_value = std::move(tmp);
p_restArgc -= 2;
p_args += 2;
m_isSet = true;
return true;
}
virtual void PrintDescription(FILE* p_output)
{
std::size_t padding = 30;
if (!m_representStringShort.empty())
{
fprintf(p_output, "%s", m_representStringShort.c_str());
padding -= m_representStringShort.size();
}
if (!m_representString.empty())
{
if (!m_representStringShort.empty())
{
fprintf(p_output, ", ");
padding -= 2;
}
fprintf(p_output, "%s", m_representString.c_str());
padding -= m_representString.size();
}
if (m_followedValue)
{
fprintf(p_output, " <value>");
padding -= 8;
}
while (padding-- > 0)
{
fputc(' ', p_output);
}
fprintf(p_output, "%s", m_description.c_str());
}
virtual bool IsRequiredButNotSet() const
{
return m_isRequired && !m_isSet;
}
private:
DataType & m_value;
std::string m_representStringShort;
std::string m_representString;
std::string m_description;
bool m_followedValue;
const DataType c_switchAsValue;
bool m_isRequired;
bool m_isSet;
};
template<typename DataType>
void AddRequiredOption(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
true,
DataType(),
true)));
}
template<typename DataType>
void AddOptionalOption(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
true,
DataType(),
false)));
}
template<typename DataType>
void AddRequiredSwitch(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
const DataType& p_switchAsValue)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
false,
p_switchAsValue,
true)));
}
template<typename DataType>
void AddOptionalSwitch(DataType& p_target,
const std::string& p_representStringShort,
const std::string& p_representString,
const std::string& p_description,
const DataType& p_switchAsValue)
{
m_arguments.emplace_back(std::shared_ptr<IArgument>(
new ArgumentT<DataType>(p_target,
p_representStringShort,
p_representString,
p_description,
false,
p_switchAsValue,
false)));
}
private:
std::vector<std::shared_ptr<IArgument>> m_arguments;
};
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_ARGUMENTSPARSER_H_

View File

@ -1,33 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_BASE64ENCODE_H_
#define _SPTAG_HELPER_BASE64ENCODE_H_
#include <cstdint>
#include <cstddef>
#include <ostream>
namespace SPTAG
{
namespace Helper
{
namespace Base64
{
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& p_outLen);
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen);
bool Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen);
std::size_t CapacityForEncode(std::size_t p_inLen);
std::size_t CapacityForDecode(std::size_t p_inLen);
} // namespace Base64
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_BASE64ENCODE_H_

View File

@ -1,40 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_COMMONHELPER_H_
#define _SPTAG_HELPER_COMMONHELPER_H_
#include "../Core/Common.h"
#include <string>
#include <vector>
#include <cctype>
#include <functional>
#include <limits>
#include <cerrno>
namespace SPTAG
{
namespace Helper
{
namespace StrUtils
{
void ToLowerInPlace(std::string& p_str);
std::vector<std::string> SplitString(const std::string& p_str, const std::string& p_separator);
std::pair<const char*, const char*> FindTrimmedSegment(const char* p_begin,
const char* p_end,
const std::function<bool(char)>& p_isSkippedChar);
bool StartsWith(const char* p_str, const char* p_prefix);
bool StrEqualIgnoreCase(const char* p_left, const char* p_right);
} // namespace StrUtils
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_COMMONHELPER_H_

View File

@ -1,97 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_CONCURRENT_H_
#define _SPTAG_HELPER_CONCURRENT_H_
#include <atomic>
#include <condition_variable>
#include <mutex>
namespace SPTAG
{
namespace Helper
{
namespace Concurrent
{
class SpinLock
{
public:
SpinLock() = default;
void Lock() noexcept
{
while (m_lock.test_and_set(std::memory_order_acquire))
{
}
}
void Unlock() noexcept
{
m_lock.clear(std::memory_order_release);
}
SpinLock(const SpinLock&) = delete;
SpinLock& operator = (const SpinLock&) = delete;
private:
std::atomic_flag m_lock = ATOMIC_FLAG_INIT;
};
template<typename Lock>
class LockGuard {
public:
LockGuard(Lock& lock) noexcept
: m_lock(lock) {
lock.Lock();
}
LockGuard(Lock& lock, std::adopt_lock_t) noexcept
: m_lock(lock) {}
~LockGuard() {
m_lock.Unlock();
}
LockGuard(const LockGuard&) = delete;
LockGuard& operator=(const LockGuard&) = delete;
private:
Lock& m_lock;
};
class WaitSignal
{
public:
WaitSignal();
WaitSignal(std::uint32_t p_unfinished);
~WaitSignal();
void Reset(std::uint32_t p_unfinished);
void Wait();
void FinishOne();
private:
std::atomic<std::uint32_t> m_unfinished;
std::atomic_bool m_isWaiting;
std::mutex m_mutex;
std::condition_variable m_cv;
};
} // namespace Base64
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_CONCURRENT_H_

View File

@ -1,97 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_INIREADER_H_
#define _SPTAG_HELPER_INIREADER_H_
#include "../Core/Common.h"
#include "StringConvert.h"
#include <vector>
#include <map>
#include <memory>
#include <string>
#include <sstream>
namespace SPTAG
{
namespace Helper
{
// Simple INI Reader with basic functions. Case insensitive.
class IniReader
{
public:
typedef std::map<std::string, std::string> ParameterValueMap;
IniReader();
~IniReader();
ErrorCode LoadIniFile(const std::string& p_iniFilePath);
bool DoesSectionExist(const std::string& p_section) const;
bool DoesParameterExist(const std::string& p_section, const std::string& p_param) const;
const ParameterValueMap& GetParameters(const std::string& p_section) const;
template <typename DataType>
DataType GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const;
void SetParameter(const std::string& p_section, const std::string& p_param, const std::string& p_val);
private:
bool GetRawValue(const std::string& p_section, const std::string& p_param, std::string& p_value) const;
template <typename DataType>
static inline DataType ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal);
private:
const static ParameterValueMap c_emptyParameters;
std::map<std::string, std::shared_ptr<ParameterValueMap>> m_parameters;
};
template <typename DataType>
DataType
IniReader::GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const
{
std::string value;
if (!GetRawValue(p_section, p_param, value))
{
return p_defaultVal;
}
return ConvertStringTo<DataType>(std::move(value), p_defaultVal);
}
template <typename DataType>
inline DataType
IniReader::ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal)
{
DataType value;
if (Convert::ConvertStringTo<DataType>(p_str.c_str(), value))
{
return value;
}
return p_defaultVal;
}
template <>
inline std::string
IniReader::ConvertStringTo<std::string>(std::string&& p_str, const std::string& p_defaultVal)
{
return std::move(p_str);
}
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_INIREADER_H_

View File

@ -1,374 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_HELPER_STRINGCONVERTHELPER_H_
#define _SPTAG_HELPER_STRINGCONVERTHELPER_H_
#include "inc/Core/Common.h"
#include "CommonHelper.h"
#include <string>
#include <cstring>
#include <sstream>
#include <cctype>
#include <limits>
#include <cerrno>
namespace SPTAG
{
namespace Helper
{
namespace Convert
{
template <typename DataType>
inline bool ConvertStringTo(const char* p_str, DataType& p_value)
{
if (nullptr == p_str)
{
return false;
}
std::istringstream sstream;
sstream.str(p_str);
if (p_str >> p_value)
{
return true;
}
return false;
}
template<typename DataType>
inline std::string ConvertToString(const DataType& p_value)
{
return std::to_string(p_value);
}
// Specialization of ConvertStringTo<>().
template <typename DataType>
inline bool ConvertStringToSignedInt(const char* p_str, DataType& p_value)
{
static_assert(std::is_integral<DataType>::value && std::is_signed<DataType>::value, "type check");
if (nullptr == p_str)
{
return false;
}
char* end = nullptr;
errno = 0;
auto val = std::strtoll(p_str, &end, 10);
if (errno == ERANGE || end == p_str || *end != '\0')
{
return false;
}
if (val < (std::numeric_limits<DataType>::min)() || val >(std::numeric_limits<DataType>::max)())
{
return false;
}
p_value = static_cast<DataType>(val);
return true;
}
template <typename DataType>
inline bool ConvertStringToUnsignedInt(const char* p_str, DataType& p_value)
{
static_assert(std::is_integral<DataType>::value && std::is_unsigned<DataType>::value, "type check");
if (nullptr == p_str)
{
return false;
}
char* end = nullptr;
errno = 0;
auto val = std::strtoull(p_str, &end, 10);
if (errno == ERANGE || end == p_str || *end != '\0')
{
return false;
}
if (val < (std::numeric_limits<DataType>::min)() || val >(std::numeric_limits<DataType>::max)())
{
return false;
}
p_value = static_cast<DataType>(val);
return true;
}
template <>
inline bool ConvertStringTo<std::string>(const char* p_str, std::string& p_value)
{
if (nullptr == p_str)
{
return false;
}
p_value = p_str;
return true;
}
template <>
inline bool ConvertStringTo<float>(const char* p_str, float& p_value)
{
if (nullptr == p_str)
{
return false;
}
char* end = nullptr;
errno = 0;
p_value = std::strtof(p_str, &end);
return (errno != ERANGE && end != p_str && *end == '\0');
}
template <>
inline bool ConvertStringTo<double>(const char* p_str, double& p_value)
{
if (nullptr == p_str)
{
return false;
}
char* end = nullptr;
errno = 0;
p_value = std::strtod(p_str, &end);
return (errno != ERANGE && end != p_str && *end == '\0');
}
template <>
inline bool ConvertStringTo<std::int8_t>(const char* p_str, std::int8_t& p_value)
{
return ConvertStringToSignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::int16_t>(const char* p_str, std::int16_t& p_value)
{
return ConvertStringToSignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::int32_t>(const char* p_str, std::int32_t& p_value)
{
return ConvertStringToSignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::int64_t>(const char* p_str, std::int64_t& p_value)
{
return ConvertStringToSignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::uint8_t>(const char* p_str, std::uint8_t& p_value)
{
return ConvertStringToUnsignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::uint16_t>(const char* p_str, std::uint16_t& p_value)
{
return ConvertStringToUnsignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::uint32_t>(const char* p_str, std::uint32_t& p_value)
{
return ConvertStringToUnsignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<std::uint64_t>(const char* p_str, std::uint64_t& p_value)
{
return ConvertStringToUnsignedInt(p_str, p_value);
}
template <>
inline bool ConvertStringTo<bool>(const char* p_str, bool& p_value)
{
if (StrUtils::StrEqualIgnoreCase(p_str, "true"))
{
p_value = true;
}
else if (StrUtils::StrEqualIgnoreCase(p_str, "false"))
{
p_value = false;
}
else
{
return false;
}
return true;
}
template <>
inline bool ConvertStringTo<IndexAlgoType>(const char* p_str, IndexAlgoType& p_value)
{
if (nullptr == p_str)
{
return false;
}
#define DefineIndexAlgo(Name) \
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
{ \
p_value = IndexAlgoType::Name; \
return true; \
} \
#include "inc/Core/DefinitionList.h"
#undef DefineIndexAlgo
return false;
}
template <>
inline bool ConvertStringTo<DistCalcMethod>(const char* p_str, DistCalcMethod& p_value)
{
if (nullptr == p_str)
{
return false;
}
#define DefineDistCalcMethod(Name) \
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
{ \
p_value = DistCalcMethod::Name; \
return true; \
} \
#include "inc/Core/DefinitionList.h"
#undef DefineDistCalcMethod
return false;
}
template <>
inline bool ConvertStringTo<VectorValueType>(const char* p_str, VectorValueType& p_value)
{
if (nullptr == p_str)
{
return false;
}
#define DefineVectorValueType(Name, Type) \
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
{ \
p_value = VectorValueType::Name; \
return true; \
} \
#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType
return false;
}
// Specialization of ConvertToString<>().
template<>
inline std::string ConvertToString<std::string>(const std::string& p_value)
{
return p_value;
}
template<>
inline std::string ConvertToString<bool>(const bool& p_value)
{
return p_value ? "true" : "false";
}
template <>
inline std::string ConvertToString<IndexAlgoType>(const IndexAlgoType& p_value)
{
switch (p_value)
{
#define DefineIndexAlgo(Name) \
case IndexAlgoType::Name: \
return #Name; \
#include "inc/Core/DefinitionList.h"
#undef DefineIndexAlgo
default:
break;
}
return "Undefined";
}
template <>
inline std::string ConvertToString<DistCalcMethod>(const DistCalcMethod& p_value)
{
switch (p_value)
{
#define DefineDistCalcMethod(Name) \
case DistCalcMethod::Name: \
return #Name; \
#include "inc/Core/DefinitionList.h"
#undef DefineDistCalcMethod
default:
break;
}
return "Undefined";
}
template <>
inline std::string ConvertToString<VectorValueType>(const VectorValueType& p_value)
{
switch (p_value)
{
#define DefineVectorValueType(Name, Type) \
case VectorValueType::Name: \
return #Name; \
#include "inc/Core/DefinitionList.h"
#undef DefineVectorValueType
default:
break;
}
return "Undefined";
}
} // namespace Convert
} // namespace Helper
} // namespace SPTAG
#endif // _SPTAG_HELPER_STRINGCONVERTHELPER_H_

View File

@ -1,47 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_OPTIONS_H_
#define _SPTAG_INDEXBUILDER_OPTIONS_H_
#include "inc/Core/Common.h"
#include "inc/Helper/ArgumentsParser.h"
#include <string>
#include <vector>
#include <memory>
namespace SPTAG
{
namespace IndexBuilder
{
class BuilderOptions : public Helper::ArgumentsParser
{
public:
BuilderOptions();
~BuilderOptions();
std::uint32_t m_threadNum;
std::uint32_t m_dimension;
std::string m_vectorDelimiter;
SPTAG::VectorValueType m_inputValueType;
std::string m_inputFiles;
std::string m_outputFolder;
SPTAG::IndexAlgoType m_indexAlgoType;
std::string m_builderConfigFile;
};
} // namespace IndexBuilder
} // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_OPTIONS_H_

View File

@ -1,27 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_THREADPOOL_H_
#define _SPTAG_INDEXBUILDER_THREADPOOL_H_
#include <functional>
#include <cstdint>
namespace SPTAG
{
namespace IndexBuilder
{
namespace ThreadPool
{
void Init(std::uint32_t p_threadNum);
bool Queue(std::function<void()> p_workItem);
std::uint32_t CurrentThreadNum();
}
} // namespace IndexBuilder
} // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_THREADPOOL_H_

View File

@ -1,43 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADER_H_
#define _SPTAG_INDEXBUILDER_VECTORSETREADER_H_
#include "inc/Core/Common.h"
#include "inc/Core/VectorSet.h"
#include "inc/Core/MetadataSet.h"
#include "Options.h"
#include <memory>
namespace SPTAG
{
namespace IndexBuilder
{
class VectorSetReader
{
public:
VectorSetReader(std::shared_ptr<BuilderOptions> p_options);
virtual ~VectorSetReader();
virtual ErrorCode LoadFile(const std::string& p_filePath) = 0;
virtual std::shared_ptr<VectorSet> GetVectorSet() const = 0;
virtual std::shared_ptr<MetadataSet> GetMetadataSet() const = 0;
static std::shared_ptr<VectorSetReader> CreateInstance(std::shared_ptr<BuilderOptions> p_options);
protected:
std::shared_ptr<BuilderOptions> m_options;
};
} // namespace IndexBuilder
} // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADER_H_

View File

@ -1,108 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_
#define _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_
#include "../VectorSetReader.h"
#include "inc/Helper/Concurrent.h"
#include <atomic>
#include <condition_variable>
#include <mutex>
namespace SPTAG
{
namespace IndexBuilder
{
class DefaultReader : public VectorSetReader
{
public:
DefaultReader(std::shared_ptr<BuilderOptions> p_options);
virtual ~DefaultReader();
virtual ErrorCode LoadFile(const std::string& p_filePaths);
virtual std::shared_ptr<VectorSet> GetVectorSet() const;
virtual std::shared_ptr<MetadataSet> GetMetadataSet() const;
private:
typedef std::pair<std::string, std::size_t> FileInfoPair;
static std::vector<FileInfoPair> GetFileSizes(const std::string& p_filePaths);
void LoadFileInternal(const std::string& p_filePath,
std::uint32_t p_subtaskID,
std::uint32_t p_fileBlockID,
std::size_t p_fileBlockSize);
void MergeData();
template<typename DataType>
bool TranslateVector(char* p_str, DataType* p_vector)
{
std::uint32_t eleCount = 0;
char* next = p_str;
while ((*next) != '\0')
{
while ((*next) != '\0' && m_options->m_vectorDelimiter.find(*next) == std::string::npos)
{
++next;
}
bool reachEnd = ('\0' == (*next));
*next = '\0';
if (p_str != next)
{
if (eleCount >= m_options->m_dimension)
{
return false;
}
if (!Helper::Convert::ConvertStringTo(p_str, p_vector[eleCount++]))
{
return false;
}
}
if (reachEnd)
{
break;
}
++next;
p_str = next;
}
return eleCount == m_options->m_dimension;
}
private:
std::uint32_t m_subTaskCount;
std::size_t m_subTaskBlocksize;
std::atomic<std::uint32_t> m_totalRecordCount;
std::atomic<std::size_t> m_totalRecordVectorBytes;
std::vector<std::uint32_t> m_subTaskRecordCount;
std::string m_vectorOutput;
std::string m_metadataConentOutput;
std::string m_metadataIndexOutput;
Helper::Concurrent::WaitSignal m_waitSignal;
};
} // namespace IndexBuilder
} // namespace SPTAG
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULT_H_

View File

@ -1,56 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_QUERYPARSER_H_
#define _SPTAG_SERVER_QUERYPARSER_H_
#include "../Core/Common.h"
#include "../Core/CommonDataStructure.h"
#include <vector>
namespace SPTAG
{
namespace Service
{
class QueryParser
{
public:
typedef std::pair<const char*, const char*> OptionPair;
QueryParser();
~QueryParser();
ErrorCode Parse(const std::string& p_query, const char* p_vectorSeparator);
const std::vector<const char*>& GetVectorElements() const;
const std::vector<OptionPair>& GetOptions() const;
const char* GetVectorBase64() const;
SizeType GetVectorBase64Length() const;
private:
std::vector<OptionPair> m_options;
std::vector<const char*> m_vectorElements;
const char* m_vectorBase64;
SizeType m_vectorBase64Length;
ByteArray m_dataHolder;
static const char* c_defaultVectorSeparator;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_QUERYPARSER_H_

View File

@ -1,80 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_
#define _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_
#include "inc/Core/VectorIndex.h"
#include "inc/Core/SearchQuery.h"
#include "inc/Socket/RemoteSearchQuery.h"
#include "ServiceSettings.h"
#include "QueryParser.h"
#include <vector>
#include <string>
#include <memory>
namespace SPTAG
{
namespace Service
{
typedef Socket::IndexSearchResult SearchResult;
class SearchExecutionContext
{
public:
SearchExecutionContext(const std::shared_ptr<const ServiceSettings>& p_serviceSettings);
~SearchExecutionContext();
ErrorCode ParseQuery(const std::string& p_query);
ErrorCode ExtractOption();
ErrorCode ExtractVector(VectorValueType p_targetType);
void AddResults(std::string p_indexName, QueryResult& p_results);
std::vector<SearchResult>& GetResults();
const std::vector<SearchResult>& GetResults() const;
const ByteArray& GetVector() const;
const std::vector<std::string>& GetSelectedIndexNames() const;
const SizeType GetVectorDimension() const;
const std::vector<QueryParser::OptionPair>& GetOptions() const;
const SizeType GetResultNum() const;
const bool GetExtractMetadata() const;
private:
const std::shared_ptr<const ServiceSettings> c_serviceSettings;
QueryParser m_queryParser;
std::vector<std::string> m_indexNames;
ByteArray m_vector;
SizeType m_vectorDimension;
std::vector<SearchResult> m_results;
VectorValueType m_inputValueType;
bool m_extractMetadata;
SizeType m_resultNum;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_

View File

@ -1,56 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_SEARCHEXECUTOR_H_
#define _SPTAG_SERVER_SEARCHEXECUTOR_H_
#include "ServiceContext.h"
#include "ServiceSettings.h"
#include "SearchExecutionContext.h"
#include "QueryParser.h"
#include <functional>
#include <memory>
#include <vector>
namespace SPTAG
{
namespace Service
{
class SearchExecutor
{
public:
typedef std::function<void(std::shared_ptr<SearchExecutionContext>)> CallBack;
SearchExecutor(std::string p_queryString,
std::shared_ptr<ServiceContext> p_serviceContext,
const CallBack& p_callback);
~SearchExecutor();
void Execute();
private:
void ExecuteInternal();
void SelectIndex();
private:
CallBack m_callback;
const std::shared_ptr<ServiceContext> c_serviceContext;
std::shared_ptr<SearchExecutionContext> m_executionContext;
std::string m_queryString;
std::vector<std::shared_ptr<VectorIndex>> m_selectedIndex;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_SEARCHEXECUTOR_H_

View File

@ -1,73 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_SERVICE_H_
#define _SPTAG_SERVER_SERVICE_H_
#include "ServiceContext.h"
#include "../Socket/Server.h"
#include <boost/asio.hpp>
#include <memory>
#include <vector>
#include <thread>
#include <condition_variable>
namespace SPTAG
{
namespace Service
{
class SearchExecutionContext;
class SearchService
{
public:
SearchService();
~SearchService();
bool Initialize(int p_argNum, char* p_args[]);
void Run();
private:
void RunSocketMode();
void RunInteractiveMode();
void SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
void SearchHanlderCallback(std::shared_ptr<SearchExecutionContext> p_exeContext,
Socket::Packet p_srcPacket);
private:
enum class ServeMode : std::uint8_t
{
Interactive,
Socket
};
std::shared_ptr<ServiceContext> m_serviceContext;
std::shared_ptr<Socket::Server> m_socketServer;
bool m_initialized;
ServeMode m_serveMode;
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
boost::asio::io_context m_ioContext;
boost::asio::signal_set m_shutdownSignals;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_SERVICE_H_

View File

@ -1,44 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_SERVICECONTEX_H_
#define _SPTAG_SERVER_SERVICECONTEX_H_
#include "inc/Core/VectorIndex.h"
#include "ServiceSettings.h"
#include <memory>
#include <map>
namespace SPTAG
{
namespace Service
{
class ServiceContext
{
public:
ServiceContext(const std::string& p_configFilePath);
~ServiceContext();
const std::map<std::string, std::shared_ptr<VectorIndex>>& GetIndexMap() const;
const std::shared_ptr<ServiceSettings>& GetServiceSettings() const;
bool IsInitialized() const;
private:
bool m_initialized;
std::shared_ptr<ServiceSettings> m_settings;
std::map<std::string, std::shared_ptr<VectorIndex>> m_fullIndexList;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_SERVICECONTEX_H_

View File

@ -1,41 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SERVER_SERVICESTTINGS_H_
#define _SPTAG_SERVER_SERVICESTTINGS_H_
#include "../Core/Common.h"
#include <string>
namespace SPTAG
{
namespace Service
{
struct ServiceSettings
{
ServiceSettings();
std::string m_vectorSeparator;
std::string m_listenAddr;
std::string m_listenPort;
SizeType m_defaultMaxResultNumber;
SizeType m_threadNum;
SizeType m_socketThreadNum;
};
} // namespace Server
} // namespace AnnService
#endif // _SPTAG_SERVER_SERVICESTTINGS_H_

View File

@ -1,68 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_CLIENT_H_
#define _SPTAG_SOCKET_CLIENT_H_
#include "inc/Core/Common.h"
#include "Connection.h"
#include "ConnectionManager.h"
#include "Packet.h"
#include <string>
#include <memory>
#include <atomic>
#include <boost/asio.hpp>
namespace SPTAG
{
namespace Socket
{
class Client
{
public:
typedef std::function<void(ConnectionID p_cid, SPTAG::ErrorCode)> ConnectCallback;
Client(const PacketHandlerMapPtr& p_handlerMap,
std::size_t p_threadNum,
std::uint32_t p_heartbeatIntervalSeconds);
~Client();
ConnectionID ConnectToServer(const std::string& p_address,
const std::string& p_port,
SPTAG::ErrorCode& p_ec);
void AsyncConnectToServer(const std::string& p_address,
const std::string& p_port,
ConnectCallback p_callback);
void SendPacket(ConnectionID p_connection, Packet p_packet, std::function<void(bool)> p_callback);
void SetEventOnConnectionClose(std::function<void(ConnectionID)> p_event);
private:
void KeepIoContext();
private:
std::atomic_bool m_stopped;
std::uint32_t m_heartbeatIntervalSeconds;
boost::asio::io_context m_ioContext;
boost::asio::deadline_timer m_deadlineTimer;
std::shared_ptr<ConnectionManager> m_connectionManager;
std::vector<std::thread> m_threadPool;
const PacketHandlerMapPtr c_requestHandlerMap;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_CLIENT_H_

View File

@ -1,25 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_COMMON_H_
#define _SPTAG_SOCKET_COMMON_H_
#include <cstdint>
namespace SPTAG
{
namespace Socket
{
typedef std::uint32_t ConnectionID;
typedef std::uint32_t ResourceID;
extern const ConnectionID c_invalidConnectionID;
extern const ResourceID c_invalidResourceID;
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_COMMON_H_

View File

@ -1,100 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_CONNECTION_H_
#define _SPTAG_SOCKET_CONNECTION_H_
#include "Packet.h"
#include <cstdint>
#include <memory>
#include <atomic>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/io_service_strand.hpp>
#include <boost/asio/deadline_timer.hpp>
namespace SPTAG
{
namespace Socket
{
class ConnectionManager;
class Connection : public std::enable_shared_from_this<Connection>
{
public:
typedef std::shared_ptr<Connection> Ptr;
Connection(ConnectionID p_connectionID,
boost::asio::ip::tcp::socket&& p_socket,
const PacketHandlerMapPtr& p_handlerMap,
std::weak_ptr<ConnectionManager> p_connectionManager);
void Start();
void Stop();
void StartHeartbeat(std::size_t p_intervalSeconds);
void AsyncSend(Packet p_packet, std::function<void(bool)> p_callback);
ConnectionID GetConnectionID() const;
ConnectionID GetRemoteConnectionID() const;
Connection(const Connection&) = delete;
Connection& operator=(const Connection&) = delete;
private:
void AsyncReadHeader();
void AsyncReadBody();
void HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred);
void HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred);
void SendHeartbeat(std::size_t p_intervalSeconds);
void SendRegister();
void HandleHeartbeatRequest();
void HandleRegisterRequest();
void HandleRegisterResponse();
void HandleNoHandlerResponse();
void OnConnectionFail(const boost::system::error_code& p_ec);
private:
const ConnectionID c_connectionID;
ConnectionID m_remoteConnectionID;
const std::weak_ptr<ConnectionManager> c_connectionManager;
const PacketHandlerMapPtr c_handlerMap;
boost::asio::ip::tcp::socket m_socket;
boost::asio::io_context::strand m_strand;
boost::asio::deadline_timer m_heartbeatTimer;
std::uint8_t m_packetHeaderReadBuffer[PacketHeader::c_bufferSize];
Packet m_packetRead;
std::atomic_bool m_stopped;
std::atomic_bool m_heartbeatStarted;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_CONNECTION_H_

View File

@ -1,73 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_CONNECTIONMANAGER_H_
#define _SPTAG_SOCKET_CONNECTIONMANAGER_H_
#include "Connection.h"
#include "inc/Helper/Concurrent.h"
#include <cstdint>
#include <memory>
#include <atomic>
#include <mutex>
#include <array>
#include <boost/asio/ip/tcp.hpp>
namespace SPTAG
{
namespace Socket
{
class ConnectionManager : public std::enable_shared_from_this<ConnectionManager>
{
public:
ConnectionManager();
ConnectionID AddConnection(boost::asio::ip::tcp::socket&& p_socket,
const PacketHandlerMapPtr& p_handlerMap,
std::uint32_t p_heartbeatIntervalSeconds);
void RemoveConnection(ConnectionID p_connectionID);
Connection::Ptr GetConnection(ConnectionID p_connectionID);
void SetEventOnRemoving(std::function<void(ConnectionID)> p_event);
void StopAll();
private:
inline static std::uint32_t GetPosition(ConnectionID p_connectionID);
private:
static constexpr std::uint32_t c_connectionPoolSize = 1 << 8;
static constexpr std::uint32_t c_connectionPoolMask = c_connectionPoolSize - 1;
struct ConnectionItem
{
ConnectionItem();
std::atomic_bool m_isEmpty;
Connection::Ptr m_connection;
};
// Start from 1. 0 means not assigned.
std::atomic<ConnectionID> m_nextConnectionID;
std::atomic<std::uint32_t> m_connectionCount;
std::array<ConnectionItem, c_connectionPoolSize> m_connections;
Helper::Concurrent::SpinLock m_spinLock;
std::function<void(ConnectionID)> m_eventOnRemoving;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_CONNECTIONMANAGER_H_

View File

@ -1,142 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_PACKET_H_
#define _SPTAG_SOCKET_PACKET_H_
#include "Common.h"
#include <cstdint>
#include <memory>
#include <functional>
#include <array>
#include <unordered_map>
namespace SPTAG
{
namespace Socket
{
enum class PacketType : std::uint8_t
{
Undefined = 0x00,
HeartbeatRequest = 0x01,
RegisterRequest = 0x02,
SearchRequest = 0x03,
ResponseMask = 0x80,
HeartbeatResponse = ResponseMask | HeartbeatRequest,
RegisterResponse = ResponseMask | RegisterRequest,
SearchResponse = ResponseMask | SearchRequest
};
enum class PacketProcessStatus : std::uint8_t
{
Ok = 0x00,
Timeout = 0x01,
Dropped = 0x02,
Failed = 0x03
};
struct PacketHeader
{
static constexpr std::size_t c_bufferSize = 16;
PacketHeader();
PacketHeader(PacketHeader&& p_right);
PacketHeader(const PacketHeader& p_right);
std::size_t WriteBuffer(std::uint8_t* p_buffer);
void ReadBuffer(const std::uint8_t* p_buffer);
PacketType m_packetType;
PacketProcessStatus m_processStatus;
std::uint32_t m_bodyLength;
// Meaning of this is different with different PacketType.
// In most request case, it means connection expeced for response.
// In most response case, it means connection which handled request.
ConnectionID m_connectionID;
ResourceID m_resourceID;
};
static_assert(sizeof(PacketHeader) <= PacketHeader::c_bufferSize, "");
class Packet
{
public:
Packet();
Packet(Packet&& p_right);
Packet(const Packet& p_right);
PacketHeader& Header();
std::uint8_t* HeaderBuffer() const;
std::uint8_t* Body() const;
std::uint8_t* Buffer() const;
std::uint32_t BufferLength() const;
std::uint32_t BufferCapacity() const;
void AllocateBuffer(std::uint32_t p_bodyCapacity);
private:
PacketHeader m_header;
std::shared_ptr<std::uint8_t> m_buffer;
std::uint32_t m_bufferCapacity;
};
struct PacketTypeHash
{
std::size_t operator()(const PacketType& p_val) const
{
return static_cast<std::size_t>(p_val);
}
};
typedef std::function<void(ConnectionID, Packet)> PacketHandler;
typedef std::unordered_map<PacketType, PacketHandler, PacketTypeHash> PacketHandlerMap;
typedef std::shared_ptr<PacketHandlerMap> PacketHandlerMapPtr;
namespace PacketTypeHelper
{
bool IsRequestPacket(PacketType p_type);
bool IsResponsePacket(PacketType p_type);
PacketType GetCrosspondingResponseType(PacketType p_type);
}
} // namespace SPTAG
} // namespace Socket
#endif // _SPTAG_SOCKET_SOCKETSERVER_H_

View File

@ -1,99 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_REMOTESEARCHQUERY_H_
#define _SPTAG_SOCKET_REMOTESEARCHQUERY_H_
#include "inc/Core/CommonDataStructure.h"
#include "inc/Core/SearchQuery.h"
#include <cstdint>
#include <memory>
#include <functional>
#include <vector>
#include <unordered_map>
namespace SPTAG
{
namespace Socket
{
// TODO: use Bond replace below structures.
struct RemoteQuery
{
static constexpr std::uint16_t MajorVersion() { return 1; }
static constexpr std::uint16_t MirrorVersion() { return 0; }
enum class QueryType : std::uint8_t
{
String = 0
};
RemoteQuery();
std::size_t EstimateBufferSize() const;
std::uint8_t* Write(std::uint8_t* p_buffer) const;
const std::uint8_t* Read(const std::uint8_t* p_buffer);
QueryType m_type;
std::string m_queryString;
};
struct IndexSearchResult
{
std::string m_indexName;
QueryResult m_results;
};
struct RemoteSearchResult
{
static constexpr std::uint16_t MajorVersion() { return 1; }
static constexpr std::uint16_t MirrorVersion() { return 0; }
enum class ResultStatus : std::uint8_t
{
Success = 0,
Timeout = 1,
FailedNetwork = 2,
FailedExecute = 3,
Dropped = 4
};
RemoteSearchResult();
RemoteSearchResult(const RemoteSearchResult& p_right);
RemoteSearchResult(RemoteSearchResult&& p_right);
RemoteSearchResult& operator=(RemoteSearchResult&& p_right);
std::size_t EstimateBufferSize() const;
std::uint8_t* Write(std::uint8_t* p_buffer) const;
const std::uint8_t* Read(const std::uint8_t* p_buffer);
ResultStatus m_status;
std::vector<IndexSearchResult> m_allIndexResults;
};
} // namespace SPTAG
} // namespace Socket
#endif // _SPTAG_SOCKET_REMOTESEARCHQUERY_H_

View File

@ -1,190 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_RESOURCEMANAGER_H_
#define _SPTAG_SOCKET_RESOURCEMANAGER_H_
#include "Common.h"
#include <boost/asio/io_context.hpp>
#include <memory>
#include <chrono>
#include <functional>
#include <atomic>
#include <mutex>
#include <deque>
#include <unordered_map>
#include <thread>
namespace std
{
typedef atomic<uint32_t> atomic_uint32_t;
}
namespace SPTAG
{
namespace Socket
{
template<typename ResourceType>
class ResourceManager : public std::enable_shared_from_this<ResourceManager<ResourceType>>
{
public:
typedef std::function<void(std::shared_ptr<ResourceType>)> TimeoutCallback;
ResourceManager()
: m_nextResourceID(1),
m_isStopped(false),
m_timeoutItemCount(0)
{
m_timeoutChecker = std::thread(&ResourceManager::StartCheckTimeout, this);
}
~ResourceManager()
{
m_isStopped = true;
m_timeoutChecker.join();
}
ResourceID Add(const std::shared_ptr<ResourceType>& p_resource,
std::uint32_t p_timeoutMilliseconds,
TimeoutCallback p_timeoutCallback)
{
ResourceID rid = m_nextResourceID.fetch_add(1);
while (c_invalidResourceID == rid)
{
rid = m_nextResourceID.fetch_add(1);
}
{
std::lock_guard<std::mutex> guard(m_resourcesMutex);
m_resources.emplace(rid, p_resource);
}
if (p_timeoutMilliseconds > 0)
{
std::unique_ptr<ResourceItem> item(new ResourceItem);
item->m_resourceID = rid;
item->m_callback = std::move(p_timeoutCallback);
item->m_expireTime = m_clock.now() + std::chrono::milliseconds(p_timeoutMilliseconds);
{
std::lock_guard<std::mutex> guard(m_timeoutListMutex);
m_timeoutList.emplace_back(std::move(item));
}
++m_timeoutItemCount;
}
return rid;
}
std::shared_ptr<ResourceType> GetAndRemove(ResourceID p_resourceID)
{
std::shared_ptr<ResourceType> ret;
std::lock_guard<std::mutex> guard(m_resourcesMutex);
auto iter = m_resources.find(p_resourceID);
if (iter != m_resources.end())
{
ret = iter->second;
m_resources.erase(iter);
}
return ret;
}
void Remove(ResourceID p_resourceID)
{
std::lock_guard<std::mutex> guard(m_resourcesMutex);
auto iter = m_resources.find(p_resourceID);
if (iter != m_resources.end())
{
m_resources.erase(iter);
}
}
private:
void StartCheckTimeout()
{
std::vector<std::unique_ptr<ResourceItem>> timeouted;
timeouted.reserve(1024);
while (!m_isStopped)
{
if (m_timeoutItemCount > 0)
{
std::lock_guard<std::mutex> guard(m_timeoutListMutex);
while (!m_timeoutList.empty()
&& m_timeoutList.front()->m_expireTime <= m_clock.now())
{
timeouted.emplace_back(std::move(m_timeoutList.front()));
m_timeoutList.pop_front();
--m_timeoutItemCount;
}
}
if (timeouted.empty())
{
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
else
{
for (auto& item : timeouted)
{
auto resource = GetAndRemove(item->m_resourceID);
if (nullptr != resource)
{
item->m_callback(std::move(resource));
}
}
timeouted.clear();
}
}
}
private:
struct ResourceItem
{
ResourceItem()
: m_resourceID(c_invalidResourceID)
{
}
ResourceID m_resourceID;
TimeoutCallback m_callback;
std::chrono::time_point<std::chrono::high_resolution_clock> m_expireTime;
};
std::deque<std::unique_ptr<ResourceItem>> m_timeoutList;
std::atomic<std::uint32_t> m_timeoutItemCount;
std::mutex m_timeoutListMutex;
std::unordered_map<ResourceID, std::shared_ptr<ResourceType>> m_resources;
std::atomic<ResourceID> m_nextResourceID;
std::mutex m_resourcesMutex;
std::chrono::high_resolution_clock m_clock;
std::thread m_timeoutChecker;
bool m_isStopped;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_RESOURCEMANAGER_H_

View File

@ -1,55 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_SERVER_H_
#define _SPTAG_SOCKET_SERVER_H_
#include "Connection.h"
#include "ConnectionManager.h"
#include "Packet.h"
#include <string>
#include <memory>
#include <boost/asio.hpp>
namespace SPTAG
{
namespace Socket
{
class Server
{
public:
Server(const std::string& p_address,
const std::string& p_port,
const PacketHandlerMapPtr& p_handlerMap,
std::size_t p_threadNum);
~Server();
void StartListen();
void SendPacket(ConnectionID p_connection, Packet p_packet, std::function<void(bool)> p_callback);
void SetEventOnConnectionClose(std::function<void(ConnectionID)> p_event);
private:
void StartAccept();
private:
boost::asio::io_context m_ioContext;
boost::asio::ip::tcp::acceptor m_acceptor;
std::shared_ptr<ConnectionManager> m_connectionManager;
std::vector<std::thread> m_threadPool;
const PacketHandlerMapPtr m_requestHandlerMap;
};
} // namespace Socket
} // namespace SPTAG
#endif // _SPTAG_SOCKET_SERVER_H_

View File

@ -1,174 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#ifndef _SPTAG_SOCKET_SIMPLESERIALIZATION_H_
#define _SPTAG_SOCKET_SIMPLESERIALIZATION_H_
#include "inc/Core/CommonDataStructure.h"
#include <type_traits>
#include <cstdint>
#include <memory>
#include <cstring>
namespace SPTAG
{
namespace Socket
{
namespace SimpleSerialization
{
template<typename T>
inline std::uint8_t*
SimpleWriteBuffer(const T& p_val, std::uint8_t* p_buffer)
{
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
"Only applied for fundanmental type.");
*(reinterpret_cast<T*>(p_buffer)) = p_val;
return p_buffer + sizeof(T);
}
template<typename T>
inline const std::uint8_t*
SimpleReadBuffer(const std::uint8_t* p_buffer, T& p_val)
{
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
"Only applied for fundanmental type.");
p_val = *(reinterpret_cast<const T*>(p_buffer));
return p_buffer + sizeof(T);
}
template<typename T>
inline std::size_t
EstimateBufferSize(const T& p_val)
{
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
"Only applied for fundanmental type.");
return sizeof(T);
}
template<>
inline std::uint8_t*
SimpleWriteBuffer<std::string>(const std::string& p_val, std::uint8_t* p_buffer)
{
p_buffer = SimpleWriteBuffer(static_cast<std::uint32_t>(p_val.size()), p_buffer);
std::memcpy(p_buffer, p_val.c_str(), p_val.size());
return p_buffer + p_val.size();
}
template<>
inline const std::uint8_t*
SimpleReadBuffer<std::string>(const std::uint8_t* p_buffer, std::string& p_val)
{
p_val.clear();
std::uint32_t len = 0;
p_buffer = SimpleReadBuffer(p_buffer, len);
if (len > 0)
{
p_val.reserve(len);
p_val.assign(reinterpret_cast<const char*>(p_buffer), len);
}
return p_buffer + len;
}
template<>
inline std::size_t
EstimateBufferSize<std::string>(const std::string& p_val)
{
return sizeof(std::uint32_t) + p_val.size();
}
template<>
inline std::uint8_t*
SimpleWriteBuffer<ByteArray>(const ByteArray& p_val, std::uint8_t* p_buffer)
{
p_buffer = SimpleWriteBuffer(static_cast<std::uint32_t>(p_val.Length()), p_buffer);
std::memcpy(p_buffer, p_val.Data(), p_val.Length());
return p_buffer + p_val.Length();
}
template<>
inline const std::uint8_t*
SimpleReadBuffer<ByteArray>(const std::uint8_t* p_buffer, ByteArray& p_val)
{
p_val.Clear();
std::uint32_t len = 0;
p_buffer = SimpleReadBuffer(p_buffer, len);
if (len > 0)
{
p_val = ByteArray::Alloc(len);
std::memcpy(p_val.Data(), p_buffer, len);
}
return p_buffer + len;
}
template<>
inline std::size_t
EstimateBufferSize<ByteArray>(const ByteArray& p_val)
{
return sizeof(std::uint32_t) + p_val.Length();
}
template<typename T>
inline std::uint8_t*
SimpleWriteSharedPtrBuffer(const std::shared_ptr<T>& p_val, std::uint8_t* p_buffer)
{
if (nullptr == p_val)
{
return SimpleWriteBuffer(false, p_buffer);
}
p_buffer = SimpleWriteBuffer(true, p_buffer);
p_buffer = SimpleWriteBuffer(*p_val, p_buffer);
return p_buffer;
}
template<typename T>
inline const std::uint8_t*
SimpleReadSharedPtrBuffer(const std::uint8_t* p_buffer, std::shared_ptr<T>& p_val)
{
p_val.reset();
bool isNotNull = false;
p_buffer = SimpleReadBuffer(p_buffer, isNotNull);
if (isNotNull)
{
p_val.reset(new T);
p_buffer = SimpleReadBuffer(p_buffer, *p_val);
}
return p_buffer;
}
template<typename T>
inline std::size_t
EstimateSharedPtrBufferSize(const std::shared_ptr<T>& p_val)
{
return sizeof(bool) + (nullptr == p_val ? 0 : EstimateBufferSize(*p_val));
}
} // namespace SimpleSerialization
} // namespace SPTAG
} // namespace Socket
#endif // _SPTAG_SOCKET_SIMPLESERIALIZATION_H_

View File

@ -1,151 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_ALLOCATOR_H
#define ARROW_ALLOCATOR_H
#include <algorithm>
#include <cstddef>
#include <memory>
#include <utility>
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
namespace arrow {
/// \brief A STL allocator delegating allocations to a Arrow MemoryPool
template <class T>
class stl_allocator {
public:
using value_type = T;
using pointer = T*;
using const_pointer = const T*;
using reference = T&;
using const_reference = const T&;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
template <class U>
struct rebind {
using other = stl_allocator<U>;
};
/// \brief Construct an allocator from the default MemoryPool
stl_allocator() noexcept : pool_(default_memory_pool()) {}
/// \brief Construct an allocator from the given MemoryPool
explicit stl_allocator(MemoryPool* pool) noexcept : pool_(pool) {}
template <class U>
stl_allocator(const stl_allocator<U>& rhs) noexcept : pool_(rhs.pool_) {}
~stl_allocator() { pool_ = NULLPTR; }
pointer address(reference r) const noexcept { return std::addressof(r); }
const_pointer address(const_reference r) const noexcept { return std::addressof(r); }
pointer allocate(size_type n, const void* /*hint*/ = NULLPTR) {
uint8_t* data;
Status s = pool_->Allocate(n * sizeof(T), &data);
if (!s.ok()) throw std::bad_alloc();
return reinterpret_cast<pointer>(data);
}
void deallocate(pointer p, size_type n) {
pool_->Free(reinterpret_cast<uint8_t*>(p), n * sizeof(T));
}
size_type size_max() const noexcept { return size_type(-1) / sizeof(T); }
template <class U, class... Args>
void construct(U* p, Args&&... args) {
new (reinterpret_cast<void*>(p)) U(std::forward<Args>(args)...);
}
template <class U>
void destroy(U* p) {
p->~U();
}
MemoryPool* pool() const noexcept { return pool_; }
private:
MemoryPool* pool_;
};
/// \brief A MemoryPool implementation delegating allocations to a STL allocator
///
/// Note that STL allocators don't provide a resizing operation, and therefore
/// any buffer resizes will do a full reallocation and copy.
template <typename Allocator = std::allocator<uint8_t>>
class STLMemoryPool : public MemoryPool {
public:
/// \brief Construct a memory pool from the given allocator
explicit STLMemoryPool(const Allocator& alloc) : alloc_(alloc) {}
Status Allocate(int64_t size, uint8_t** out) override {
try {
*out = alloc_.allocate(size);
} catch (std::bad_alloc& e) {
return Status::OutOfMemory(e.what());
}
stats_.UpdateAllocatedBytes(size);
return Status::OK();
}
Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override {
uint8_t* old_ptr = *ptr;
try {
*ptr = alloc_.allocate(new_size);
} catch (std::bad_alloc& e) {
return Status::OutOfMemory(e.what());
}
memcpy(*ptr, old_ptr, std::min(old_size, new_size));
alloc_.deallocate(old_ptr, old_size);
stats_.UpdateAllocatedBytes(new_size - old_size);
return Status::OK();
}
void Free(uint8_t* buffer, int64_t size) override {
alloc_.deallocate(buffer, size);
stats_.UpdateAllocatedBytes(-size);
}
int64_t bytes_allocated() const override { return stats_.bytes_allocated(); }
int64_t max_memory() const override { return stats_.max_memory(); }
private:
Allocator alloc_;
internal::MemoryPoolStats stats_;
};
template <class T1, class T2>
bool operator==(const stl_allocator<T1>& lhs, const stl_allocator<T2>& rhs) noexcept {
return lhs.pool() == rhs.pool();
}
template <class T1, class T2>
bool operator!=(const stl_allocator<T1>& lhs, const stl_allocator<T2>& rhs) noexcept {
return !(lhs == rhs);
}
} // namespace arrow
#endif // ARROW_ALLOCATOR_H

View File

@ -1,43 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Coarse public API while the library is in development
#ifndef ARROW_API_H
#define ARROW_API_H
#include "arrow/array.h" // IYWU pragma: export
#include "arrow/buffer.h" // IYWU pragma: export
#include "arrow/builder.h" // IYWU pragma: export
#include "arrow/compare.h" // IYWU pragma: export
#include "arrow/extension_type.h" // IYWU pragma: export
#include "arrow/memory_pool.h" // IYWU pragma: export
#include "arrow/pretty_print.h" // IYWU pragma: export
#include "arrow/record_batch.h" // IYWU pragma: export
#include "arrow/status.h" // IYWU pragma: export
#include "arrow/table.h" // IYWU pragma: export
#include "arrow/table_builder.h" // IYWU pragma: export
#include "arrow/tensor.h" // IYWU pragma: export
#include "arrow/type.h" // IYWU pragma: export
#include "arrow/util/config.h" // IYWU pragma: export
#include "arrow/util/key_value_metadata.h" // IWYU pragma: export
#include "arrow/visitor.h" // IYWU pragma: export
/// \brief Top-level namespace for Apache Arrow C++ API
namespace arrow {}
#endif // ARROW_API_H

File diff suppressed because it is too large Load Diff

View File

@ -1,175 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/array/builder_base.h"
namespace arrow {
namespace internal {
class ARROW_EXPORT AdaptiveIntBuilderBase : public ArrayBuilder {
public:
explicit AdaptiveIntBuilderBase(MemoryPool* pool);
/// \brief Append multiple nulls
/// \param[in] length the number of nulls to append
Status AppendNulls(int64_t length) final {
ARROW_RETURN_NOT_OK(CommitPendingData());
ARROW_RETURN_NOT_OK(Reserve(length));
memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length);
UnsafeSetNull(length);
return Status::OK();
}
Status AppendNull() final {
pending_data_[pending_pos_] = 0;
pending_valid_[pending_pos_] = 0;
pending_has_nulls_ = true;
++pending_pos_;
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
return CommitPendingData();
}
return Status::OK();
}
void Reset() override;
Status Resize(int64_t capacity) override;
protected:
virtual Status CommitPendingData() = 0;
std::shared_ptr<ResizableBuffer> data_;
uint8_t* raw_data_;
uint8_t int_size_;
static constexpr int32_t pending_size_ = 1024;
uint8_t pending_valid_[pending_size_];
uint64_t pending_data_[pending_size_];
int32_t pending_pos_;
bool pending_has_nulls_;
};
} // namespace internal
class ARROW_EXPORT AdaptiveUIntBuilder : public internal::AdaptiveIntBuilderBase {
public:
explicit AdaptiveUIntBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
using ArrayBuilder::Advance;
using internal::AdaptiveIntBuilderBase::Reset;
/// Scalar append
Status Append(const uint64_t val) {
pending_data_[pending_pos_] = val;
pending_valid_[pending_pos_] = 1;
++pending_pos_;
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
return CommitPendingData();
}
return Status::OK();
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const uint64_t* values, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
protected:
Status CommitPendingData() override;
Status ExpandIntSize(uint8_t new_int_size);
Status AppendValuesInternal(const uint64_t* values, int64_t length,
const uint8_t* valid_bytes);
template <typename new_type, typename old_type>
typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
ExpandIntSizeInternal();
#define __LESS(a, b) (a) < (b)
template <typename new_type, typename old_type>
typename std::enable_if<__LESS(sizeof(old_type), sizeof(new_type)), Status>::type
ExpandIntSizeInternal();
#undef __LESS
template <typename new_type>
Status ExpandIntSizeN();
};
class ARROW_EXPORT AdaptiveIntBuilder : public internal::AdaptiveIntBuilderBase {
public:
explicit AdaptiveIntBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
using ArrayBuilder::Advance;
using internal::AdaptiveIntBuilderBase::Reset;
/// Scalar append
Status Append(const int64_t val) {
auto v = static_cast<uint64_t>(val);
pending_data_[pending_pos_] = v;
pending_valid_[pending_pos_] = 1;
++pending_pos_;
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
return CommitPendingData();
}
return Status::OK();
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const int64_t* values, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
protected:
Status CommitPendingData() override;
Status ExpandIntSize(uint8_t new_int_size);
Status AppendValuesInternal(const int64_t* values, int64_t length,
const uint8_t* valid_bytes);
template <typename new_type, typename old_type>
typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
ExpandIntSizeInternal();
#define __LESS(a, b) (a) < (b)
template <typename new_type, typename old_type>
typename std::enable_if<__LESS(sizeof(old_type), sizeof(new_type)), Status>::type
ExpandIntSizeInternal();
#undef __LESS
template <typename new_type>
Status ExpandIntSizeN();
};
} // namespace arrow

View File

@ -1,219 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <algorithm> // IWYU pragma: keep
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>
#include <vector>
#include "arrow/buffer-builder.h"
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/macros.h"
#include "arrow/util/type_traits.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
struct ArrayData;
class MemoryPool;
constexpr int64_t kMinBuilderCapacity = 1 << 5;
constexpr int64_t kListMaximumElements = std::numeric_limits<int32_t>::max() - 1;
/// Base class for all data array builders.
///
/// This class provides a facilities for incrementally building the null bitmap
/// (see Append methods) and as a side effect the current number of slots and
/// the null count.
///
/// \note Users are expected to use builders as one of the concrete types below.
/// For example, ArrayBuilder* pointing to BinaryBuilder should be downcast before use.
class ARROW_EXPORT ArrayBuilder {
public:
explicit ArrayBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: type_(type), pool_(pool), null_bitmap_builder_(pool) {}
virtual ~ArrayBuilder() = default;
/// For nested types. Since the objects are owned by this class instance, we
/// skip shared pointers and just return a raw pointer
ArrayBuilder* child(int i) { return children_[i].get(); }
int num_children() const { return static_cast<int>(children_.size()); }
int64_t length() const { return length_; }
int64_t null_count() const { return null_count_; }
int64_t capacity() const { return capacity_; }
/// \brief Ensure that enough memory has been allocated to fit the indicated
/// number of total elements in the builder, including any that have already
/// been appended. Does not account for reallocations that may be due to
/// variable size data, like binary values. To make space for incremental
/// appends, use Reserve instead.
///
/// \param[in] capacity the minimum number of total array values to
/// accommodate. Must be greater than the current capacity.
/// \return Status
virtual Status Resize(int64_t capacity);
/// \brief Ensure that there is enough space allocated to add the indicated
/// number of elements without any further calls to Resize. Overallocation is
/// used in order to minimize the impact of incremental Reserve() calls.
///
/// \param[in] additional_capacity the number of additional array values
/// \return Status
Status Reserve(int64_t additional_capacity) {
auto current_capacity = capacity();
auto min_capacity = length() + additional_capacity;
if (min_capacity <= current_capacity) return Status::OK();
// leave growth factor up to BufferBuilder
auto new_capacity = BufferBuilder::GrowByFactor(current_capacity, min_capacity);
return Resize(new_capacity);
}
/// Reset the builder.
virtual void Reset();
virtual Status AppendNull() = 0;
virtual Status AppendNulls(int64_t length) = 0;
/// For cases where raw data was memcpy'd into the internal buffers, allows us
/// to advance the length of the builder. It is your responsibility to use
/// this function responsibly.
Status Advance(int64_t elements);
/// \brief Return result of builder as an internal generic ArrayData
/// object. Resets builder except for dictionary builder
///
/// \param[out] out the finalized ArrayData object
/// \return Status
virtual Status FinishInternal(std::shared_ptr<ArrayData>* out) = 0;
/// \brief Return result of builder as an Array object.
///
/// The builder is reset except for DictionaryBuilder.
///
/// \param[out] out the finalized Array object
/// \return Status
Status Finish(std::shared_ptr<Array>* out);
std::shared_ptr<DataType> type() const { return type_; }
protected:
/// Append to null bitmap
Status AppendToBitmap(bool is_valid);
/// Vector append. Treat each zero byte as a null. If valid_bytes is null
/// assume all of length bits are valid.
Status AppendToBitmap(const uint8_t* valid_bytes, int64_t length);
/// Uniform append. Append N times the same validity bit.
Status AppendToBitmap(int64_t num_bits, bool value);
/// Set the next length bits to not null (i.e. valid).
Status SetNotNull(int64_t length);
// Unsafe operations (don't check capacity/don't resize)
void UnsafeAppendNull() { UnsafeAppendToBitmap(false); }
// Append to null bitmap, update the length
void UnsafeAppendToBitmap(bool is_valid) {
null_bitmap_builder_.UnsafeAppend(is_valid);
++length_;
if (!is_valid) ++null_count_;
}
// Vector append. Treat each zero byte as a nullzero. If valid_bytes is null
// assume all of length bits are valid.
void UnsafeAppendToBitmap(const uint8_t* valid_bytes, int64_t length) {
if (valid_bytes == NULLPTR) {
return UnsafeSetNotNull(length);
}
null_bitmap_builder_.UnsafeAppend(valid_bytes, length);
length_ += length;
null_count_ = null_bitmap_builder_.false_count();
}
// Append the same validity value a given number of times.
void UnsafeAppendToBitmap(const int64_t num_bits, bool value) {
if (value) {
UnsafeSetNotNull(num_bits);
} else {
UnsafeSetNull(num_bits);
}
}
void UnsafeAppendToBitmap(const std::vector<bool>& is_valid);
// Set the next validity bits to not null (i.e. valid).
void UnsafeSetNotNull(int64_t length);
// Set the next validity bits to null (i.e. invalid).
void UnsafeSetNull(int64_t length);
static Status TrimBuffer(const int64_t bytes_filled, ResizableBuffer* buffer);
/// \brief Finish to an array of the specified ArrayType
template <typename ArrayType>
Status FinishTyped(std::shared_ptr<ArrayType>* out) {
std::shared_ptr<Array> out_untyped;
ARROW_RETURN_NOT_OK(Finish(&out_untyped));
*out = std::static_pointer_cast<ArrayType>(std::move(out_untyped));
return Status::OK();
}
static Status CheckCapacity(int64_t new_capacity, int64_t old_capacity) {
if (new_capacity < 0) {
return Status::Invalid("Resize capacity must be positive");
}
if (new_capacity < old_capacity) {
return Status::Invalid("Resize cannot downsize");
}
return Status::OK();
}
std::shared_ptr<DataType> type_;
MemoryPool* pool_;
TypedBufferBuilder<bool> null_bitmap_builder_;
int64_t null_count_ = 0;
// Array length, so far. Also, the index of the next element to be added
int64_t length_ = 0;
int64_t capacity_ = 0;
// Child value array builders. These are owned by this class
std::vector<std::shared_ptr<ArrayBuilder>> children_;
private:
ARROW_DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
};
} // namespace arrow

View File

@ -1,395 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/buffer-builder.h"
#include "arrow/status.h"
#include "arrow/type_traits.h"
#include "arrow/util/macros.h"
#include "arrow/util/string_view.h" // IWYU pragma: export
namespace arrow {
constexpr int64_t kBinaryMemoryLimit = std::numeric_limits<int32_t>::max() - 1;
// ----------------------------------------------------------------------
// Binary and String
/// \class BinaryBuilder
/// \brief Builder class for variable-length binary data
class ARROW_EXPORT BinaryBuilder : public ArrayBuilder {
public:
explicit BinaryBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
BinaryBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool);
Status Append(const uint8_t* value, int32_t length) {
ARROW_RETURN_NOT_OK(Reserve(1));
ARROW_RETURN_NOT_OK(AppendNextOffset());
// Safety check for UBSAN.
if (ARROW_PREDICT_TRUE(length > 0)) {
ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
}
UnsafeAppendToBitmap(true);
return Status::OK();
}
Status AppendNulls(int64_t length) final {
const int64_t num_bytes = value_data_builder_.length();
if (ARROW_PREDICT_FALSE(num_bytes > kBinaryMemoryLimit)) {
return AppendOverflow(num_bytes);
}
ARROW_RETURN_NOT_OK(Reserve(length));
for (int64_t i = 0; i < length; ++i) {
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
}
UnsafeAppendToBitmap(length, false);
return Status::OK();
}
Status AppendNull() final {
ARROW_RETURN_NOT_OK(AppendNextOffset());
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppendToBitmap(false);
return Status::OK();
}
Status Append(const char* value, int32_t length) {
return Append(reinterpret_cast<const uint8_t*>(value), length);
}
Status Append(util::string_view value) {
return Append(value.data(), static_cast<int32_t>(value.size()));
}
/// \brief Append without checking capacity
///
/// Offsets and data should have been presized using Reserve() and
/// ReserveData(), respectively.
void UnsafeAppend(const uint8_t* value, int32_t length) {
UnsafeAppendNextOffset();
value_data_builder_.UnsafeAppend(value, length);
UnsafeAppendToBitmap(true);
}
void UnsafeAppend(const char* value, int32_t length) {
UnsafeAppend(reinterpret_cast<const uint8_t*>(value), length);
}
void UnsafeAppend(const std::string& value) {
UnsafeAppend(value.c_str(), static_cast<int32_t>(value.size()));
}
void UnsafeAppend(util::string_view value) {
UnsafeAppend(value.data(), static_cast<int32_t>(value.size()));
}
void UnsafeAppendNull() {
const int64_t num_bytes = value_data_builder_.length();
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
UnsafeAppendToBitmap(false);
}
void Reset() override;
Status Resize(int64_t capacity) override;
/// \brief Ensures there is enough allocated capacity to append the indicated
/// number of bytes to the value data buffer without additional allocations
Status ReserveData(int64_t elements);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<BinaryArray>* out) { return FinishTyped(out); }
/// \return size of values buffer so far
int64_t value_data_length() const { return value_data_builder_.length(); }
/// \return capacity of values buffer
int64_t value_data_capacity() const { return value_data_builder_.capacity(); }
/// Temporary access to a value.
///
/// This pointer becomes invalid on the next modifying operation.
const uint8_t* GetValue(int64_t i, int32_t* out_length) const;
/// Temporary access to a value.
///
/// This view becomes invalid on the next modifying operation.
util::string_view GetView(int64_t i) const;
protected:
TypedBufferBuilder<int32_t> offsets_builder_;
TypedBufferBuilder<uint8_t> value_data_builder_;
Status AppendOverflow(int64_t num_bytes);
Status AppendNextOffset() {
const int64_t num_bytes = value_data_builder_.length();
if (ARROW_PREDICT_FALSE(num_bytes > kBinaryMemoryLimit)) {
return AppendOverflow(num_bytes);
}
return offsets_builder_.Append(static_cast<int32_t>(num_bytes));
}
void UnsafeAppendNextOffset() {
const int64_t num_bytes = value_data_builder_.length();
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
}
};
/// \class StringBuilder
/// \brief Builder class for UTF8 strings
class ARROW_EXPORT StringBuilder : public BinaryBuilder {
public:
using BinaryBuilder::BinaryBuilder;
explicit StringBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
using BinaryBuilder::Append;
using BinaryBuilder::Reset;
using BinaryBuilder::UnsafeAppend;
/// \brief Append a sequence of strings in one shot.
///
/// \param[in] values a vector of strings
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const std::vector<std::string>& values,
const uint8_t* valid_bytes = NULLPTR);
/// \brief Append a sequence of nul-terminated strings in one shot.
/// If one of the values is NULL, it is processed as a null
/// value even if the corresponding valid_bytes entry is 1.
///
/// \param[in] values a contiguous C array of nul-terminated char *
/// \param[in] length the number of values to append
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const char** values, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<StringArray>* out) { return FinishTyped(out); }
};
// ----------------------------------------------------------------------
// FixedSizeBinaryBuilder
class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
public:
FixedSizeBinaryBuilder(const std::shared_ptr<DataType>& type,
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
Status Append(const uint8_t* value) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(value);
return Status::OK();
}
Status Append(const char* value) {
return Append(reinterpret_cast<const uint8_t*>(value));
}
Status Append(const util::string_view& view) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(view);
return Status::OK();
}
Status Append(const std::string& s) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(s);
return Status::OK();
}
template <size_t NBYTES>
Status Append(const std::array<uint8_t, NBYTES>& value) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(
util::string_view(reinterpret_cast<const char*>(value.data()), value.size()));
return Status::OK();
}
Status AppendValues(const uint8_t* data, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
Status AppendNull() final;
Status AppendNulls(int64_t length) final;
void UnsafeAppend(const uint8_t* value) {
UnsafeAppendToBitmap(true);
if (ARROW_PREDICT_TRUE(byte_width_ > 0)) {
byte_builder_.UnsafeAppend(value, byte_width_);
}
}
void UnsafeAppend(util::string_view value) {
#ifndef NDEBUG
CheckValueSize(static_cast<size_t>(value.size()));
#endif
UnsafeAppend(reinterpret_cast<const uint8_t*>(value.data()));
}
void UnsafeAppendNull() {
UnsafeAppendToBitmap(false);
byte_builder_.UnsafeAdvance(byte_width_);
}
void Reset() override;
Status Resize(int64_t capacity) override;
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<FixedSizeBinaryArray>* out) { return FinishTyped(out); }
/// \return size of values buffer so far
int64_t value_data_length() const { return byte_builder_.length(); }
int32_t byte_width() const { return byte_width_; }
/// Temporary access to a value.
///
/// This pointer becomes invalid on the next modifying operation.
const uint8_t* GetValue(int64_t i) const;
/// Temporary access to a value.
///
/// This view becomes invalid on the next modifying operation.
util::string_view GetView(int64_t i) const;
protected:
int32_t byte_width_;
BufferBuilder byte_builder_;
/// Temporary access to a value.
///
/// This pointer becomes invalid on the next modifying operation.
uint8_t* GetMutableValue(int64_t i) {
uint8_t* data_ptr = byte_builder_.mutable_data();
return data_ptr + i * byte_width_;
}
#ifndef NDEBUG
void CheckValueSize(int64_t size);
#endif
};
// ----------------------------------------------------------------------
// Chunked builders: build a sequence of BinaryArray or StringArray that are
// limited to a particular size (to the upper limit of 2GB)
namespace internal {
class ARROW_EXPORT ChunkedBinaryBuilder {
public:
ChunkedBinaryBuilder(int32_t max_chunk_value_length,
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
ChunkedBinaryBuilder(int32_t max_chunk_value_length, int32_t max_chunk_length,
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
virtual ~ChunkedBinaryBuilder() = default;
Status Append(const uint8_t* value, int32_t length) {
if (ARROW_PREDICT_FALSE(length + builder_->value_data_length() >
max_chunk_value_length_)) {
if (builder_->value_data_length() == 0) {
// The current item is larger than max_chunk_size_;
// this chunk will be oversize and hold *only* this item
ARROW_RETURN_NOT_OK(builder_->Append(value, length));
return NextChunk();
}
// The current item would cause builder_->value_data_length() to exceed
// max_chunk_size_, so finish this chunk and append the current item to the next
// chunk
ARROW_RETURN_NOT_OK(NextChunk());
return Append(value, length);
}
if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
// The current item would cause builder_->value_data_length() to exceed
// max_chunk_size_, so finish this chunk and append the current item to the next
// chunk
ARROW_RETURN_NOT_OK(NextChunk());
}
return builder_->Append(value, length);
}
Status Append(const util::string_view& value) {
return Append(reinterpret_cast<const uint8_t*>(value.data()),
static_cast<int32_t>(value.size()));
}
Status AppendNull() {
if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
ARROW_RETURN_NOT_OK(NextChunk());
}
return builder_->AppendNull();
}
Status Reserve(int64_t values);
virtual Status Finish(ArrayVector* out);
protected:
Status NextChunk();
// maximum total character data size per chunk
int64_t max_chunk_value_length_;
// maximum elements allowed per chunk
int64_t max_chunk_length_ = kListMaximumElements;
// when Reserve() would cause builder_ to exceed its max_chunk_length_,
// add to extra_capacity_ instead and wait to reserve until the next chunk
int64_t extra_capacity_ = 0;
std::unique_ptr<BinaryBuilder> builder_;
std::vector<std::shared_ptr<Array>> chunks_;
};
class ARROW_EXPORT ChunkedStringBuilder : public ChunkedBinaryBuilder {
public:
using ChunkedBinaryBuilder::ChunkedBinaryBuilder;
Status Finish(ArrayVector* out) override;
};
} // namespace internal
} // namespace arrow

View File

@ -1,53 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/array/builder_base.h"
#include "arrow/array/builder_binary.h"
namespace arrow {
class Decimal128;
class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder {
public:
explicit Decimal128Builder(const std::shared_ptr<DataType>& type,
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
using FixedSizeBinaryBuilder::Append;
using FixedSizeBinaryBuilder::AppendValues;
using FixedSizeBinaryBuilder::Reset;
Status Append(Decimal128 val);
void UnsafeAppend(Decimal128 val);
void UnsafeAppend(util::string_view val);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<Decimal128Array>* out) { return FinishTyped(out); }
};
using DecimalBuilder = Decimal128Builder;
} // namespace arrow

View File

@ -1,369 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <algorithm>
#include <memory>
#include "arrow/array/builder_adaptive.h" // IWYU pragma: export
#include "arrow/array/builder_base.h" // IWYU pragma: export
#include "arrow/array.h"
namespace arrow {
// ----------------------------------------------------------------------
// Dictionary builder
namespace internal {
template <typename T>
struct DictionaryScalar {
using type = typename T::c_type;
};
template <>
struct DictionaryScalar<BinaryType> {
using type = util::string_view;
};
template <>
struct DictionaryScalar<StringType> {
using type = util::string_view;
};
template <>
struct DictionaryScalar<FixedSizeBinaryType> {
using type = util::string_view;
};
class ARROW_EXPORT DictionaryMemoTable {
public:
explicit DictionaryMemoTable(const std::shared_ptr<DataType>& type);
explicit DictionaryMemoTable(const std::shared_ptr<Array>& dictionary);
~DictionaryMemoTable();
int32_t GetOrInsert(const bool& value);
int32_t GetOrInsert(const int8_t& value);
int32_t GetOrInsert(const int16_t& value);
int32_t GetOrInsert(const int32_t& value);
int32_t GetOrInsert(const int64_t& value);
int32_t GetOrInsert(const uint8_t& value);
int32_t GetOrInsert(const uint16_t& value);
int32_t GetOrInsert(const uint32_t& value);
int32_t GetOrInsert(const uint64_t& value);
int32_t GetOrInsert(const float& value);
int32_t GetOrInsert(const double& value);
int32_t GetOrInsert(const util::string_view& value);
Status GetArrayData(MemoryPool* pool, int64_t start_offset,
std::shared_ptr<ArrayData>* out);
int32_t size() const;
private:
class DictionaryMemoTableImpl;
std::unique_ptr<DictionaryMemoTableImpl> impl_;
};
} // namespace internal
/// \brief Array builder for created encoded DictionaryArray from
/// dense array
///
/// Unlike other builders, dictionary builder does not completely
/// reset the state on Finish calls. The arrays built after the
/// initial Finish call will reuse the previously created encoding and
/// build a delta dictionary when new terms occur.
///
/// data
template <typename T>
class DictionaryBuilder : public ArrayBuilder {
public:
using Scalar = typename internal::DictionaryScalar<T>::type;
// WARNING: the type given below is the value type, not the DictionaryType.
// The DictionaryType is instantiated on the Finish() call.
template <typename T1 = T>
DictionaryBuilder(
typename std::enable_if<!std::is_base_of<FixedSizeBinaryType, T1>::value,
const std::shared_ptr<DataType>&>::type type,
MemoryPool* pool)
: ArrayBuilder(type, pool),
memo_table_(new internal::DictionaryMemoTable(type)),
delta_offset_(0),
byte_width_(-1),
values_builder_(pool) {}
template <typename T1 = T>
explicit DictionaryBuilder(
typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
const std::shared_ptr<DataType>&>::type type,
MemoryPool* pool)
: ArrayBuilder(type, pool),
memo_table_(new internal::DictionaryMemoTable(type)),
delta_offset_(0),
byte_width_(static_cast<const T1&>(*type).byte_width()),
values_builder_(pool) {}
template <typename T1 = T>
explicit DictionaryBuilder(
typename std::enable_if<TypeTraits<T1>::is_parameter_free, MemoryPool*>::type pool)
: DictionaryBuilder<T1>(TypeTraits<T1>::type_singleton(), pool) {}
DictionaryBuilder(const std::shared_ptr<Array>& dictionary, MemoryPool* pool)
: ArrayBuilder(dictionary->type(), pool),
memo_table_(new internal::DictionaryMemoTable(dictionary)),
delta_offset_(0),
byte_width_(-1),
values_builder_(pool) {}
~DictionaryBuilder() override = default;
/// \brief Append a scalar value
Status Append(const Scalar& value) {
ARROW_RETURN_NOT_OK(Reserve(1));
auto memo_index = memo_table_->GetOrInsert(value);
ARROW_RETURN_NOT_OK(values_builder_.Append(memo_index));
length_ += 1;
return Status::OK();
}
/// \brief Append a fixed-width string (only for FixedSizeBinaryType)
template <typename T1 = T>
Status Append(typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
const uint8_t*>::type value) {
return Append(util::string_view(reinterpret_cast<const char*>(value), byte_width_));
}
/// \brief Append a fixed-width string (only for FixedSizeBinaryType)
template <typename T1 = T>
Status Append(typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
const char*>::type value) {
return Append(util::string_view(value, byte_width_));
}
/// \brief Append a scalar null value
Status AppendNull() final {
length_ += 1;
null_count_ += 1;
return values_builder_.AppendNull();
}
Status AppendNulls(int64_t length) final {
length_ += length;
null_count_ += length;
return values_builder_.AppendNulls(length);
}
/// \brief Append a whole dense array to the builder
template <typename T1 = T>
Status AppendArray(
typename std::enable_if<!std::is_base_of<FixedSizeBinaryType, T1>::value,
const Array&>::type array) {
using ArrayType = typename TypeTraits<T>::ArrayType;
const auto& concrete_array = static_cast<const ArrayType&>(array);
for (int64_t i = 0; i < array.length(); i++) {
if (array.IsNull(i)) {
ARROW_RETURN_NOT_OK(AppendNull());
} else {
ARROW_RETURN_NOT_OK(Append(concrete_array.GetView(i)));
}
}
return Status::OK();
}
template <typename T1 = T>
Status AppendArray(
typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
const Array&>::type array) {
if (!type_->Equals(*array.type())) {
return Status::Invalid(
"Cannot append FixedSizeBinary array with non-matching type");
}
const auto& concrete_array = static_cast<const FixedSizeBinaryArray&>(array);
for (int64_t i = 0; i < array.length(); i++) {
if (array.IsNull(i)) {
ARROW_RETURN_NOT_OK(AppendNull());
} else {
ARROW_RETURN_NOT_OK(Append(concrete_array.GetValue(i)));
}
}
return Status::OK();
}
void Reset() override {
ArrayBuilder::Reset();
values_builder_.Reset();
memo_table_.reset(new internal::DictionaryMemoTable(type_));
delta_offset_ = 0;
}
Status Resize(int64_t capacity) override {
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
capacity = std::max(capacity, kMinBuilderCapacity);
if (capacity_ == 0) {
// Initialize hash table
// XXX should we let the user pass additional size heuristics?
delta_offset_ = 0;
}
ARROW_RETURN_NOT_OK(values_builder_.Resize(capacity));
capacity_ = values_builder_.capacity();
return Status::OK();
}
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
// Finalize indices array
ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out));
// Generate dictionary array from hash table contents
std::shared_ptr<ArrayData> dictionary_data;
ARROW_RETURN_NOT_OK(
memo_table_->GetArrayData(pool_, delta_offset_, &dictionary_data));
// Set type of array data to the right dictionary type
(*out)->type = dictionary((*out)->type, type_);
(*out)->dictionary = MakeArray(dictionary_data);
// Update internals for further uses of this DictionaryBuilder
delta_offset_ = memo_table_->size();
values_builder_.Reset();
return Status::OK();
}
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
/// is the dictionary builder in the delta building mode
bool is_building_delta() { return delta_offset_ > 0; }
protected:
std::unique_ptr<internal::DictionaryMemoTable> memo_table_;
int32_t delta_offset_;
// Only used for FixedSizeBinaryType
int32_t byte_width_;
AdaptiveIntBuilder values_builder_;
};
template <>
class DictionaryBuilder<NullType> : public ArrayBuilder {
public:
DictionaryBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: ArrayBuilder(type, pool), values_builder_(pool) {}
explicit DictionaryBuilder(MemoryPool* pool)
: ArrayBuilder(null(), pool), values_builder_(pool) {}
DictionaryBuilder(const std::shared_ptr<Array>& dictionary, MemoryPool* pool)
: ArrayBuilder(dictionary->type(), pool), values_builder_(pool) {}
/// \brief Append a scalar null value
Status AppendNull() final {
length_ += 1;
null_count_ += 1;
return values_builder_.AppendNull();
}
Status AppendNulls(int64_t length) final {
length_ += length;
null_count_ += length;
return values_builder_.AppendNulls(length);
}
/// \brief Append a whole dense array to the builder
Status AppendArray(const Array& array) {
for (int64_t i = 0; i < array.length(); i++) {
ARROW_RETURN_NOT_OK(AppendNull());
}
return Status::OK();
}
Status Resize(int64_t capacity) override {
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
capacity = std::max(capacity, kMinBuilderCapacity);
ARROW_RETURN_NOT_OK(values_builder_.Resize(capacity));
capacity_ = values_builder_.capacity();
return Status::OK();
}
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
std::shared_ptr<Array> dictionary = std::make_shared<NullArray>(0);
ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out));
(*out)->type = std::make_shared<DictionaryType>((*out)->type, type_);
(*out)->dictionary = dictionary;
return Status::OK();
}
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
protected:
AdaptiveIntBuilder values_builder_;
};
class ARROW_EXPORT BinaryDictionaryBuilder : public DictionaryBuilder<BinaryType> {
public:
using DictionaryBuilder::Append;
using DictionaryBuilder::DictionaryBuilder;
Status Append(const uint8_t* value, int32_t length) {
return Append(reinterpret_cast<const char*>(value), length);
}
Status Append(const char* value, int32_t length) {
return Append(util::string_view(value, length));
}
};
/// \brief Dictionary array builder with convenience methods for strings
class ARROW_EXPORT StringDictionaryBuilder : public DictionaryBuilder<StringType> {
public:
using DictionaryBuilder::Append;
using DictionaryBuilder::DictionaryBuilder;
Status Append(const uint8_t* value, int32_t length) {
return Append(reinterpret_cast<const char*>(value), length);
}
Status Append(const char* value, int32_t length) {
return Append(util::string_view(value, length));
}
};
} // namespace arrow

View File

@ -1,260 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <vector>
#include "arrow/array/builder_base.h"
#include "arrow/buffer-builder.h"
namespace arrow {
// ----------------------------------------------------------------------
// List builder
/// \class ListBuilder
/// \brief Builder class for variable-length list array value types
///
/// To use this class, you must append values to the child array builder and use
/// the Append function to delimit each distinct list value (once the values
/// have been appended to the child array) or use the bulk API to append
/// a sequence of offests and null values.
///
/// A note on types. Per arrow/type.h all types in the c++ implementation are
/// logical so even though this class always builds list array, this can
/// represent multiple different logical types. If no logical type is provided
/// at construction time, the class defaults to List<T> where t is taken from the
/// value_builder/values that the object is constructed with.
class ARROW_EXPORT ListBuilder : public ArrayBuilder {
public:
/// Use this constructor to incrementally build the value array along with offsets and
/// null bitmap.
ListBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& value_builder,
const std::shared_ptr<DataType>& type = NULLPTR);
Status Resize(int64_t capacity) override;
void Reset() override;
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<ListArray>* out) { return FinishTyped(out); }
/// \brief Vector append
///
/// If passed, valid_bytes is of equal length to values, and any zero byte
/// will be considered as a null for that slot
Status AppendValues(const int32_t* offsets, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
/// \brief Start a new variable-length list slot
///
/// This function should be called before beginning to append elements to the
/// value builder
Status Append(bool is_valid = true);
Status AppendNull() final { return Append(false); }
Status AppendNulls(int64_t length) final;
ArrayBuilder* value_builder() const;
protected:
TypedBufferBuilder<int32_t> offsets_builder_;
std::shared_ptr<ArrayBuilder> value_builder_;
std::shared_ptr<Array> values_;
Status CheckNextOffset() const;
Status AppendNextOffset();
Status AppendNextOffset(int64_t num_repeats);
};
// ----------------------------------------------------------------------
// Map builder
/// \class MapBuilder
/// \brief Builder class for arrays of variable-size maps
///
/// To use this class, you must append values to the key and item array builders
/// and use the Append function to delimit each distinct map (once the keys and items
/// have been appended) or use the bulk API to append a sequence of offests and null
/// maps.
///
/// Key uniqueness and ordering are not validated.
class ARROW_EXPORT MapBuilder : public ArrayBuilder {
public:
/// Use this constructor to incrementally build the key and item arrays along with
/// offsets and null bitmap.
MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
const std::shared_ptr<ArrayBuilder>& item_builder,
const std::shared_ptr<DataType>& type);
/// Derive built type from key and item builders' types
MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
const std::shared_ptr<ArrayBuilder>& item_builder, bool keys_sorted = false);
Status Resize(int64_t capacity) override;
void Reset() override;
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<MapArray>* out) { return FinishTyped(out); }
/// \brief Vector append
///
/// If passed, valid_bytes is of equal length to values, and any zero byte
/// will be considered as a null for that slot
Status AppendValues(const int32_t* offsets, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
/// \brief Start a new variable-length map slot
///
/// This function should be called before beginning to append elements to the
/// key and value builders
Status Append();
Status AppendNull() final;
Status AppendNulls(int64_t length) final;
ArrayBuilder* key_builder() const { return key_builder_.get(); }
ArrayBuilder* item_builder() const { return item_builder_.get(); }
protected:
std::shared_ptr<ListBuilder> list_builder_;
std::shared_ptr<ArrayBuilder> key_builder_;
std::shared_ptr<ArrayBuilder> item_builder_;
};
// ----------------------------------------------------------------------
// FixedSizeList builder
/// \class FixedSizeListBuilder
/// \brief Builder class for fixed-length list array value types
class ARROW_EXPORT FixedSizeListBuilder : public ArrayBuilder {
public:
FixedSizeListBuilder(MemoryPool* pool,
std::shared_ptr<ArrayBuilder> const& value_builder,
int32_t list_size);
FixedSizeListBuilder(MemoryPool* pool,
std::shared_ptr<ArrayBuilder> const& value_builder,
const std::shared_ptr<DataType>& type);
Status Resize(int64_t capacity) override;
void Reset() override;
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<FixedSizeListArray>* out) { return FinishTyped(out); }
/// \brief Append a valid fixed length list.
///
/// This function affects only the validity bitmap; the child values must be appended
/// using the child array builder.
Status Append();
/// \brief Vector append
///
/// If passed, valid_bytes wil be read and any zero byte
/// will cause the corresponding slot to be null
///
/// This function affects only the validity bitmap; the child values must be appended
/// using the child array builder. This includes appending nulls for null lists.
/// XXX this restriction is confusing, should this method be omitted?
Status AppendValues(int64_t length, const uint8_t* valid_bytes = NULLPTR);
/// \brief Append a null fixed length list.
///
/// The child array builder will have the approriate number of nulls appended
/// automatically.
Status AppendNull() final;
/// \brief Append length null fixed length lists.
///
/// The child array builder will have the approriate number of nulls appended
/// automatically.
Status AppendNulls(int64_t length) final;
ArrayBuilder* value_builder() const { return value_builder_.get(); }
protected:
const int32_t list_size_;
std::shared_ptr<ArrayBuilder> value_builder_;
};
// ----------------------------------------------------------------------
// Struct
// ---------------------------------------------------------------------------------
// StructArray builder
/// Append, Resize and Reserve methods are acting on StructBuilder.
/// Please make sure all these methods of all child-builders' are consistently
/// called to maintain data-structure consistency.
class ARROW_EXPORT StructBuilder : public ArrayBuilder {
public:
StructBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool,
std::vector<std::shared_ptr<ArrayBuilder>>&& field_builders);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<StructArray>* out) { return FinishTyped(out); }
/// Null bitmap is of equal length to every child field, and any zero byte
/// will be considered as a null for that field, but users must using app-
/// end methods or advance methods of the child builders' independently to
/// insert data.
Status AppendValues(int64_t length, const uint8_t* valid_bytes) {
ARROW_RETURN_NOT_OK(Reserve(length));
UnsafeAppendToBitmap(valid_bytes, length);
return Status::OK();
}
/// Append an element to the Struct. All child-builders' Append method must
/// be called independently to maintain data-structure consistency.
Status Append(bool is_valid = true) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppendToBitmap(is_valid);
return Status::OK();
}
Status AppendNull() final { return Append(false); }
Status AppendNulls(int64_t length) final;
void Reset() override;
ArrayBuilder* field_builder(int i) const { return children_[i].get(); }
int num_fields() const { return static_cast<int>(children_.size()); }
};
} // namespace arrow

View File

@ -1,429 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <vector>
#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/type.h"
namespace arrow {
class ARROW_EXPORT NullBuilder : public ArrayBuilder {
public:
explicit NullBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: ArrayBuilder(null(), pool) {}
/// \brief Append the specified number of null elements
Status AppendNulls(int64_t length) final {
if (length < 0) return Status::Invalid("length must be positive");
null_count_ += length;
length_ += length;
return Status::OK();
}
/// \brief Append a single null element
Status AppendNull() final { return AppendNulls(1); }
Status Append(std::nullptr_t) { return AppendNull(); }
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<NullArray>* out) { return FinishTyped(out); }
};
/// Base class for all Builders that emit an Array of a scalar numerical type.
template <typename T>
class NumericBuilder : public ArrayBuilder {
public:
using value_type = typename T::c_type;
using ArrayType = typename TypeTraits<T>::ArrayType;
using ArrayBuilder::ArrayBuilder;
template <typename T1 = T>
explicit NumericBuilder(
typename std::enable_if<TypeTraits<T1>::is_parameter_free, MemoryPool*>::type pool
ARROW_MEMORY_POOL_DEFAULT)
: ArrayBuilder(TypeTraits<T1>::type_singleton(), pool) {}
/// Append a single scalar and increase the size if necessary.
Status Append(const value_type val) {
ARROW_RETURN_NOT_OK(ArrayBuilder::Reserve(1));
UnsafeAppend(val);
return Status::OK();
}
/// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
/// The memory at the corresponding data slot is set to 0 to prevent
/// uninitialized memory access
Status AppendNulls(int64_t length) final {
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(length, static_cast<value_type>(0));
UnsafeSetNull(length);
return Status::OK();
}
/// \brief Append a single null element
Status AppendNull() final {
ARROW_RETURN_NOT_OK(Reserve(1));
data_builder_.UnsafeAppend(static_cast<value_type>(0));
UnsafeAppendToBitmap(false);
return Status::OK();
}
value_type GetValue(int64_t index) const { return data_builder_.data()[index]; }
void Reset() override { data_builder_.Reset(); }
Status Resize(int64_t capacity) override {
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
capacity = std::max(capacity, kMinBuilderCapacity);
ARROW_RETURN_NOT_OK(data_builder_.Resize(capacity));
return ArrayBuilder::Resize(capacity);
}
value_type operator[](int64_t index) const { return GetValue(index); }
value_type& operator[](int64_t index) {
return reinterpret_cast<value_type*>(data_builder_.mutable_data())[index];
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const value_type* values, int64_t length,
const uint8_t* valid_bytes = NULLPTR) {
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(values, length);
// length_ is update by these
ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, length);
return Status::OK();
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
/// (0). Equal in length to values
/// \return Status
Status AppendValues(const value_type* values, int64_t length,
const std::vector<bool>& is_valid) {
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(values, length);
// length_ is update by these
ArrayBuilder::UnsafeAppendToBitmap(is_valid);
return Status::OK();
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a std::vector of values
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
/// (0). Equal in length to values
/// \return Status
Status AppendValues(const std::vector<value_type>& values,
const std::vector<bool>& is_valid) {
return AppendValues(values.data(), static_cast<int64_t>(values.size()), is_valid);
}
/// \brief Append a sequence of elements in one shot
/// \param[in] values a std::vector of values
/// \return Status
Status AppendValues(const std::vector<value_type>& values) {
return AppendValues(values.data(), static_cast<int64_t>(values.size()));
}
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
std::shared_ptr<Buffer> data, null_bitmap;
ARROW_RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
ARROW_RETURN_NOT_OK(data_builder_.Finish(&data));
*out = ArrayData::Make(type_, length_, {null_bitmap, data}, null_count_);
capacity_ = length_ = null_count_ = 0;
return Status::OK();
}
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<ArrayType>* out) { return FinishTyped(out); }
/// \brief Append a sequence of elements in one shot
/// \param[in] values_begin InputIterator to the beginning of the values
/// \param[in] values_end InputIterator pointing to the end of the values
/// \return Status
template <typename ValuesIter>
Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(values_begin, values_end);
// this updates the length_
UnsafeSetNotNull(length);
return Status::OK();
}
/// \brief Append a sequence of elements in one shot, with a specified nullmap
/// \param[in] values_begin InputIterator to the beginning of the values
/// \param[in] values_end InputIterator pointing to the end of the values
/// \param[in] valid_begin InputIterator with elements indication valid(1)
/// or null(0) values.
/// \return Status
template <typename ValuesIter, typename ValidIter>
typename std::enable_if<!std::is_pointer<ValidIter>::value, Status>::type AppendValues(
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
static_assert(!internal::is_null_pointer<ValidIter>::value,
"Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
"version instead");
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(values_begin, values_end);
null_bitmap_builder_.UnsafeAppend<true>(
length, [&valid_begin]() -> bool { return *valid_begin++; });
length_ = null_bitmap_builder_.length();
null_count_ = null_bitmap_builder_.false_count();
return Status::OK();
}
// Same as above, with a pointer type ValidIter
template <typename ValuesIter, typename ValidIter>
typename std::enable_if<std::is_pointer<ValidIter>::value, Status>::type AppendValues(
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(values_begin, values_end);
// this updates the length_
if (valid_begin == NULLPTR) {
UnsafeSetNotNull(length);
} else {
null_bitmap_builder_.UnsafeAppend<true>(
length, [&valid_begin]() -> bool { return *valid_begin++; });
length_ = null_bitmap_builder_.length();
null_count_ = null_bitmap_builder_.false_count();
}
return Status::OK();
}
/// Append a single scalar under the assumption that the underlying Buffer is
/// large enough.
///
/// This method does not capacity-check; make sure to call Reserve
/// beforehand.
void UnsafeAppend(const value_type val) {
ArrayBuilder::UnsafeAppendToBitmap(true);
data_builder_.UnsafeAppend(val);
}
void UnsafeAppendNull() {
ArrayBuilder::UnsafeAppendToBitmap(false);
data_builder_.UnsafeAppend(0);
}
protected:
TypedBufferBuilder<value_type> data_builder_;
};
// Builders
using UInt8Builder = NumericBuilder<UInt8Type>;
using UInt16Builder = NumericBuilder<UInt16Type>;
using UInt32Builder = NumericBuilder<UInt32Type>;
using UInt64Builder = NumericBuilder<UInt64Type>;
using Int8Builder = NumericBuilder<Int8Type>;
using Int16Builder = NumericBuilder<Int16Type>;
using Int32Builder = NumericBuilder<Int32Type>;
using Int64Builder = NumericBuilder<Int64Type>;
using HalfFloatBuilder = NumericBuilder<HalfFloatType>;
using FloatBuilder = NumericBuilder<FloatType>;
using DoubleBuilder = NumericBuilder<DoubleType>;
class ARROW_EXPORT BooleanBuilder : public ArrayBuilder {
public:
using value_type = bool;
explicit BooleanBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
explicit BooleanBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool);
/// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
Status AppendNulls(int64_t length) final {
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend(length, false);
UnsafeSetNull(length);
return Status::OK();
}
Status AppendNull() final {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppendNull();
return Status::OK();
}
/// Scalar append
Status Append(const bool val) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(val);
return Status::OK();
}
Status Append(const uint8_t val) { return Append(val != 0); }
/// Scalar append, without checking for capacity
void UnsafeAppend(const bool val) {
data_builder_.UnsafeAppend(val);
UnsafeAppendToBitmap(true);
}
void UnsafeAppendNull() {
data_builder_.UnsafeAppend(false);
UnsafeAppendToBitmap(false);
}
void UnsafeAppend(const uint8_t val) { UnsafeAppend(val != 0); }
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous array of bytes (non-zero is 1)
/// \param[in] length the number of values to append
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
/// indicates a valid (non-null) value
/// \return Status
Status AppendValues(const uint8_t* values, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
/// \brief Append a sequence of elements in one shot
/// \param[in] values a contiguous C array of values
/// \param[in] length the number of values to append
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
/// (0). Equal in length to values
/// \return Status
Status AppendValues(const uint8_t* values, int64_t length,
const std::vector<bool>& is_valid);
/// \brief Append a sequence of elements in one shot
/// \param[in] values a std::vector of bytes
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
/// (0). Equal in length to values
/// \return Status
Status AppendValues(const std::vector<uint8_t>& values,
const std::vector<bool>& is_valid);
/// \brief Append a sequence of elements in one shot
/// \param[in] values a std::vector of bytes
/// \return Status
Status AppendValues(const std::vector<uint8_t>& values);
/// \brief Append a sequence of elements in one shot
/// \param[in] values an std::vector<bool> indicating true (1) or false
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
/// (0). Equal in length to values
/// \return Status
Status AppendValues(const std::vector<bool>& values, const std::vector<bool>& is_valid);
/// \brief Append a sequence of elements in one shot
/// \param[in] values an std::vector<bool> indicating true (1) or false
/// \return Status
Status AppendValues(const std::vector<bool>& values);
/// \brief Append a sequence of elements in one shot
/// \param[in] values_begin InputIterator to the beginning of the values
/// \param[in] values_end InputIterator pointing to the end of the values
/// or null(0) values
/// \return Status
template <typename ValuesIter>
Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend<false>(
length, [&values_begin]() -> bool { return *values_begin++; });
// this updates length_
UnsafeSetNotNull(length);
return Status::OK();
}
/// \brief Append a sequence of elements in one shot, with a specified nullmap
/// \param[in] values_begin InputIterator to the beginning of the values
/// \param[in] values_end InputIterator pointing to the end of the values
/// \param[in] valid_begin InputIterator with elements indication valid(1)
/// or null(0) values
/// \return Status
template <typename ValuesIter, typename ValidIter>
typename std::enable_if<!std::is_pointer<ValidIter>::value, Status>::type AppendValues(
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
static_assert(!internal::is_null_pointer<ValidIter>::value,
"Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
"version instead");
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend<false>(
length, [&values_begin]() -> bool { return *values_begin++; });
null_bitmap_builder_.UnsafeAppend<true>(
length, [&valid_begin]() -> bool { return *valid_begin++; });
length_ = null_bitmap_builder_.length();
null_count_ = null_bitmap_builder_.false_count();
return Status::OK();
}
// Same as above, for a pointer type ValidIter
template <typename ValuesIter, typename ValidIter>
typename std::enable_if<std::is_pointer<ValidIter>::value, Status>::type AppendValues(
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
ARROW_RETURN_NOT_OK(Reserve(length));
data_builder_.UnsafeAppend<false>(
length, [&values_begin]() -> bool { return *values_begin++; });
if (valid_begin == NULLPTR) {
UnsafeSetNotNull(length);
} else {
null_bitmap_builder_.UnsafeAppend<true>(
length, [&valid_begin]() -> bool { return *valid_begin++; });
}
length_ = null_bitmap_builder_.length();
null_count_ = null_bitmap_builder_.false_count();
return Status::OK();
}
Status AppendValues(int64_t length, bool value);
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<BooleanArray>* out) { return FinishTyped(out); }
void Reset() override;
Status Resize(int64_t capacity) override;
protected:
TypedBufferBuilder<bool> data_builder_;
};
} // namespace arrow

View File

@ -1,70 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Contains declarations of time related Arrow builder types.
#pragma once
#include <memory>
#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/buffer-builder.h"
#include "arrow/status.h"
#include "arrow/type_traits.h"
#include "arrow/util/macros.h"
namespace arrow {
class ARROW_EXPORT DayTimeIntervalBuilder : public ArrayBuilder {
public:
using DayMilliseconds = DayTimeIntervalType::DayMilliseconds;
explicit DayTimeIntervalBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: DayTimeIntervalBuilder(day_time_interval(), pool) {}
DayTimeIntervalBuilder(std::shared_ptr<DataType> type,
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: ArrayBuilder(type, pool),
builder_(fixed_size_binary(sizeof(DayMilliseconds)), pool) {}
void Reset() override { builder_.Reset(); }
Status Resize(int64_t capacity) override { return builder_.Resize(capacity); }
Status Append(DayMilliseconds day_millis) {
return builder_.Append(reinterpret_cast<uint8_t*>(&day_millis));
}
void UnsafeAppend(DayMilliseconds day_millis) {
builder_.UnsafeAppend(reinterpret_cast<uint8_t*>(&day_millis));
}
using ArrayBuilder::UnsafeAppendNull;
Status AppendNull() override { return builder_.AppendNull(); }
Status AppendNulls(int64_t length) override { return builder_.AppendNulls(length); }
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
auto result = builder_.FinishInternal(out);
if (*out != NULLPTR) {
(*out)->type = type();
}
return result;
}
private:
FixedSizeBinaryBuilder builder_;
};
} // namespace arrow

View File

@ -1,106 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "arrow/array.h"
#include "arrow/array/builder_base.h"
#include "arrow/buffer-builder.h"
namespace arrow {
/// \class DenseUnionBuilder
///
/// You need to call AppendChild for each of the children builders you want
/// to use. The function will return an int8_t, which is the type tag
/// associated with that child. You can then call Append with that tag
/// (followed by an append on the child builder) to add elements to
/// the union array.
///
/// You can either specify the type when the UnionBuilder is constructed
/// or let the UnionBuilder infer the type at runtime (by omitting the
/// type argument from the constructor).
///
/// This API is EXPERIMENTAL.
class ARROW_EXPORT DenseUnionBuilder : public ArrayBuilder {
public:
/// Use this constructor to incrementally build the union array along
/// with types, offsets, and null bitmap.
explicit DenseUnionBuilder(MemoryPool* pool,
const std::shared_ptr<DataType>& type = NULLPTR);
Status AppendNull() final {
ARROW_RETURN_NOT_OK(types_builder_.Append(0));
ARROW_RETURN_NOT_OK(offsets_builder_.Append(0));
return AppendToBitmap(false);
}
Status AppendNulls(int64_t length) final {
ARROW_RETURN_NOT_OK(types_builder_.Reserve(length));
ARROW_RETURN_NOT_OK(offsets_builder_.Reserve(length));
ARROW_RETURN_NOT_OK(Reserve(length));
for (int64_t i = 0; i < length; ++i) {
types_builder_.UnsafeAppend(0);
offsets_builder_.UnsafeAppend(0);
}
return AppendToBitmap(length, false);
}
/// \brief Append an element to the UnionArray. This must be followed
/// by an append to the appropriate child builder.
/// \param[in] type index of the child the value will be appended
/// \param[in] offset offset of the value in that child
Status Append(int8_t type, int32_t offset) {
ARROW_RETURN_NOT_OK(types_builder_.Append(type));
ARROW_RETURN_NOT_OK(offsets_builder_.Append(offset));
return AppendToBitmap(true);
}
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
/// \cond FALSE
using ArrayBuilder::Finish;
/// \endcond
Status Finish(std::shared_ptr<UnionArray>* out) { return FinishTyped(out); }
/// \brief Make a new child builder available to the UnionArray
///
/// \param[in] child the child builder
/// \param[in] field_name the name of the field in the union array type
/// if type inference is used
/// \return child index, which is the "type" argument that needs
/// to be passed to the "Append" method to add a new element to
/// the union array.
int8_t AppendChild(const std::shared_ptr<ArrayBuilder>& child,
const std::string& field_name = "") {
children_.push_back(child);
field_names_.push_back(field_name);
return static_cast<int8_t>(children_.size() - 1);
}
private:
TypedBufferBuilder<int8_t> types_builder_;
TypedBufferBuilder<int32_t> offsets_builder_;
std::vector<std::string> field_names_;
};
} // namespace arrow

View File

@ -1,39 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <vector>
#include "arrow/array.h"
#include "arrow/memory_pool.h"
#include "arrow/util/visibility.h"
namespace arrow {
/// \brief Concatenate arrays
///
/// \param[in] arrays a vector of arrays to be concatenated
/// \param[in] pool memory to store the result will be allocated from this memory pool
/// \param[out] out the resulting concatenated array
/// \return Status
ARROW_EXPORT
Status Concatenate(const ArrayVector& arrays, MemoryPool* pool,
std::shared_ptr<Array>* out);
} // namespace arrow

View File

@ -1,379 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_BUFFER_BUILDER_H
#define ARROW_BUFFER_BUILDER_H
#include <algorithm>
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include "arrow/buffer.h"
#include "arrow/status.h"
#include "arrow/util/bit-util.h"
#include "arrow/util/macros.h"
#include "arrow/util/ubsan.h"
#include "arrow/util/visibility.h"
namespace arrow {
// ----------------------------------------------------------------------
// Buffer builder classes
/// \class BufferBuilder
/// \brief A class for incrementally building a contiguous chunk of in-memory
/// data
class ARROW_EXPORT BufferBuilder {
public:
explicit BufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: pool_(pool),
data_(/*ensure never null to make ubsan happy and avoid check penalties below*/
&util::internal::non_null_filler),
capacity_(0),
size_(0) {}
/// \brief Resize the buffer to the nearest multiple of 64 bytes
///
/// \param new_capacity the new capacity of the of the builder. Will be
/// rounded up to a multiple of 64 bytes for padding \param shrink_to_fit if
/// new capacity is smaller than the existing size, reallocate internal
/// buffer. Set to false to avoid reallocations when shrinking the builder.
/// \return Status
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
// Resize(0) is a no-op
if (new_capacity == 0) {
return Status::OK();
}
if (buffer_ == NULLPTR) {
ARROW_RETURN_NOT_OK(AllocateResizableBuffer(pool_, new_capacity, &buffer_));
} else {
ARROW_RETURN_NOT_OK(buffer_->Resize(new_capacity, shrink_to_fit));
}
capacity_ = buffer_->capacity();
data_ = buffer_->mutable_data();
return Status::OK();
}
/// \brief Ensure that builder can accommodate the additional number of bytes
/// without the need to perform allocations
///
/// \param[in] additional_bytes number of additional bytes to make space for
/// \return Status
Status Reserve(const int64_t additional_bytes) {
auto min_capacity = size_ + additional_bytes;
if (min_capacity <= capacity_) {
return Status::OK();
}
return Resize(GrowByFactor(capacity_, min_capacity), false);
}
/// \brief Return a capacity expanded by an unspecified growth factor
static int64_t GrowByFactor(int64_t current_capacity, int64_t new_capacity) {
// NOTE: Doubling isn't a great overallocation practice
// see https://github.com/facebook/folly/blob/master/folly/docs/FBVector.md
// for discussion.
// Grow exactly if a large upsize (the caller might know the exact final size).
// Otherwise overallocate by 1.5 to keep a linear amortized cost.
return std::max(new_capacity, current_capacity * 3 / 2);
}
/// \brief Append the given data to the buffer
///
/// The buffer is automatically expanded if necessary.
Status Append(const void* data, const int64_t length) {
if (ARROW_PREDICT_FALSE(size_ + length > capacity_)) {
ARROW_RETURN_NOT_OK(Resize(GrowByFactor(capacity_, size_ + length), false));
}
UnsafeAppend(data, length);
return Status::OK();
}
/// \brief Append copies of a value to the buffer
///
/// The buffer is automatically expanded if necessary.
Status Append(const int64_t num_copies, uint8_t value) {
ARROW_RETURN_NOT_OK(Reserve(num_copies));
UnsafeAppend(num_copies, value);
return Status::OK();
}
// Advance pointer and zero out memory
Status Advance(const int64_t length) { return Append(length, 0); }
// Advance pointer, but don't allocate or zero memory
void UnsafeAdvance(const int64_t length) { size_ += length; }
// Unsafe methods don't check existing size
void UnsafeAppend(const void* data, const int64_t length) {
memcpy(data_ + size_, data, static_cast<size_t>(length));
size_ += length;
}
void UnsafeAppend(const int64_t num_copies, uint8_t value) {
memset(data_ + size_, value, static_cast<size_t>(num_copies));
size_ += num_copies;
}
/// \brief Return result of builder as a Buffer object.
///
/// The builder is reset and can be reused afterwards.
///
/// \param[out] out the finalized Buffer object
/// \param shrink_to_fit if the buffer size is smaller than its capacity,
/// reallocate to fit more tightly in memory. Set to false to avoid
/// a reallocation, at the expense of potentially more memory consumption.
/// \return Status
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit));
if (size_ != 0) buffer_->ZeroPadding();
*out = buffer_;
if (*out == NULLPTR) {
ARROW_RETURN_NOT_OK(AllocateBuffer(pool_, 0, out));
}
Reset();
return Status::OK();
}
void Reset() {
buffer_ = NULLPTR;
capacity_ = size_ = 0;
}
/// \brief Set size to a smaller value without modifying builder
/// contents. For reusable BufferBuilder classes
/// \param[in] position must be non-negative and less than or equal
/// to the current length()
void Rewind(int64_t position) { size_ = position; }
int64_t capacity() const { return capacity_; }
int64_t length() const { return size_; }
const uint8_t* data() const { return data_; }
uint8_t* mutable_data() { return data_; }
private:
std::shared_ptr<ResizableBuffer> buffer_;
MemoryPool* pool_;
uint8_t* data_;
int64_t capacity_;
int64_t size_;
};
template <typename T, typename Enable = void>
class TypedBufferBuilder;
/// \brief A BufferBuilder for building a buffer of arithmetic elements
template <typename T>
class TypedBufferBuilder<T, typename std::enable_if<std::is_arithmetic<T>::value>::type> {
public:
explicit TypedBufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: bytes_builder_(pool) {}
Status Append(T value) {
return bytes_builder_.Append(reinterpret_cast<uint8_t*>(&value), sizeof(T));
}
Status Append(const T* values, int64_t num_elements) {
return bytes_builder_.Append(reinterpret_cast<const uint8_t*>(values),
num_elements * sizeof(T));
}
Status Append(const int64_t num_copies, T value) {
ARROW_RETURN_NOT_OK(Reserve(num_copies + length()));
UnsafeAppend(num_copies, value);
return Status::OK();
}
void UnsafeAppend(T value) {
bytes_builder_.UnsafeAppend(reinterpret_cast<uint8_t*>(&value), sizeof(T));
}
void UnsafeAppend(const T* values, int64_t num_elements) {
bytes_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values),
num_elements * sizeof(T));
}
template <typename Iter>
void UnsafeAppend(Iter values_begin, Iter values_end) {
int64_t num_elements = static_cast<int64_t>(std::distance(values_begin, values_end));
auto data = mutable_data() + length();
bytes_builder_.UnsafeAdvance(num_elements * sizeof(T));
std::copy(values_begin, values_end, data);
}
void UnsafeAppend(const int64_t num_copies, T value) {
auto data = mutable_data() + length();
bytes_builder_.UnsafeAppend(num_copies * sizeof(T), 0);
for (const auto end = data + num_copies; data != end; ++data) {
*data = value;
}
}
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
return bytes_builder_.Resize(new_capacity * sizeof(T), shrink_to_fit);
}
Status Reserve(const int64_t additional_elements) {
return bytes_builder_.Reserve(additional_elements * sizeof(T));
}
Status Advance(const int64_t length) {
return bytes_builder_.Advance(length * sizeof(T));
}
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
return bytes_builder_.Finish(out, shrink_to_fit);
}
void Reset() { bytes_builder_.Reset(); }
int64_t length() const { return bytes_builder_.length() / sizeof(T); }
int64_t capacity() const { return bytes_builder_.capacity() / sizeof(T); }
const T* data() const { return reinterpret_cast<const T*>(bytes_builder_.data()); }
T* mutable_data() { return reinterpret_cast<T*>(bytes_builder_.mutable_data()); }
private:
BufferBuilder bytes_builder_;
};
/// \brief A BufferBuilder for building a buffer containing a bitmap
template <>
class TypedBufferBuilder<bool> {
public:
explicit TypedBufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
: bytes_builder_(pool) {}
Status Append(bool value) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(value);
return Status::OK();
}
Status Append(const uint8_t* valid_bytes, int64_t num_elements) {
ARROW_RETURN_NOT_OK(Reserve(num_elements));
UnsafeAppend(valid_bytes, num_elements);
return Status::OK();
}
Status Append(const int64_t num_copies, bool value) {
ARROW_RETURN_NOT_OK(Reserve(num_copies));
UnsafeAppend(num_copies, value);
return Status::OK();
}
void UnsafeAppend(bool value) {
BitUtil::SetBitTo(mutable_data(), bit_length_, value);
if (!value) {
++false_count_;
}
++bit_length_;
}
void UnsafeAppend(const uint8_t* bytes, int64_t num_elements) {
if (num_elements == 0) return;
int64_t i = 0;
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
bool value = bytes[i++];
false_count_ += !value;
return value;
});
bit_length_ += num_elements;
}
void UnsafeAppend(const int64_t num_copies, bool value) {
BitUtil::SetBitsTo(mutable_data(), bit_length_, num_copies, value);
false_count_ += num_copies * !value;
bit_length_ += num_copies;
}
template <bool count_falses, typename Generator>
void UnsafeAppend(const int64_t num_elements, Generator&& gen) {
if (num_elements == 0) return;
if (count_falses) {
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
bool value = gen();
false_count_ += !value;
return value;
});
} else {
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements,
std::forward<Generator>(gen));
}
bit_length_ += num_elements;
}
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
const int64_t old_byte_capacity = bytes_builder_.capacity();
ARROW_RETURN_NOT_OK(
bytes_builder_.Resize(BitUtil::BytesForBits(new_capacity), shrink_to_fit));
// Resize() may have chosen a larger capacity (e.g. for padding),
// so ask it again before calling memset().
const int64_t new_byte_capacity = bytes_builder_.capacity();
if (new_byte_capacity > old_byte_capacity) {
// The additional buffer space is 0-initialized for convenience,
// so that other methods can simply bump the length.
memset(mutable_data() + old_byte_capacity, 0,
static_cast<size_t>(new_byte_capacity - old_byte_capacity));
}
return Status::OK();
}
Status Reserve(const int64_t additional_elements) {
return Resize(
BufferBuilder::GrowByFactor(bit_length_, bit_length_ + additional_elements),
false);
}
Status Advance(const int64_t length) {
ARROW_RETURN_NOT_OK(Reserve(length));
bit_length_ += length;
false_count_ += length;
return Status::OK();
}
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
// set bytes_builder_.size_ == byte size of data
bytes_builder_.UnsafeAdvance(BitUtil::BytesForBits(bit_length_) -
bytes_builder_.length());
bit_length_ = false_count_ = 0;
return bytes_builder_.Finish(out, shrink_to_fit);
}
void Reset() {
bytes_builder_.Reset();
bit_length_ = false_count_ = 0;
}
int64_t length() const { return bit_length_; }
int64_t capacity() const { return bytes_builder_.capacity() * 8; }
const uint8_t* data() const { return bytes_builder_.data(); }
uint8_t* mutable_data() { return bytes_builder_.mutable_data(); }
int64_t false_count() const { return false_count_; }
private:
BufferBuilder bytes_builder_;
int64_t bit_length_ = 0;
int64_t false_count_ = 0;
};
} // namespace arrow
#endif // ARROW_BUFFER_BUILDER_H

View File

@ -1,444 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_BUFFER_H
#define ARROW_BUFFER_H
#include <cstdint>
#include <cstring>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/string_view.h"
#include "arrow/util/visibility.h"
namespace arrow {
// ----------------------------------------------------------------------
// Buffer classes
/// \class Buffer
/// \brief Object containing a pointer to a piece of contiguous memory with a
/// particular size.
///
/// Buffers have two related notions of length: size and capacity. Size is
/// the number of bytes that might have valid data. Capacity is the number
/// of bytes that were allocated for the buffer in total.
///
/// The Buffer base class does not own its memory, but subclasses often do.
///
/// The following invariant is always true: Size <= Capacity
class ARROW_EXPORT Buffer {
public:
/// \brief Construct from buffer and size without copying memory
///
/// \param[in] data a memory buffer
/// \param[in] size buffer size
///
/// \note The passed memory must be kept alive through some other means
Buffer(const uint8_t* data, int64_t size)
: is_mutable_(false),
data_(data),
mutable_data_(NULLPTR),
size_(size),
capacity_(size) {}
/// \brief Construct from string_view without copying memory
///
/// \param[in] data a string_view object
///
/// \note The memory viewed by data must not be deallocated in the lifetime of the
/// Buffer; temporary rvalue strings must be stored in an lvalue somewhere
explicit Buffer(util::string_view data)
: Buffer(reinterpret_cast<const uint8_t*>(data.data()),
static_cast<int64_t>(data.size())) {}
virtual ~Buffer() = default;
/// An offset into data that is owned by another buffer, but we want to be
/// able to retain a valid pointer to it even after other shared_ptr's to the
/// parent buffer have been destroyed
///
/// This method makes no assertions about alignment or padding of the buffer but
/// in general we expected buffers to be aligned and padded to 64 bytes. In the future
/// we might add utility methods to help determine if a buffer satisfies this contract.
Buffer(const std::shared_ptr<Buffer>& parent, const int64_t offset, const int64_t size)
: Buffer(parent->data() + offset, size) {
parent_ = parent;
}
uint8_t operator[](std::size_t i) const { return data_[i]; }
bool is_mutable() const { return is_mutable_; }
/// \brief Construct a new std::string with a hexadecimal representation of the buffer.
/// \return std::string
std::string ToHexString();
/// Return true if both buffers are the same size and contain the same bytes
/// up to the number of compared bytes
bool Equals(const Buffer& other, int64_t nbytes) const;
/// Return true if both buffers are the same size and contain the same bytes
bool Equals(const Buffer& other) const;
/// Copy a section of the buffer into a new Buffer.
Status Copy(const int64_t start, const int64_t nbytes, MemoryPool* pool,
std::shared_ptr<Buffer>* out) const;
/// Copy a section of the buffer using the default memory pool into a new Buffer.
Status Copy(const int64_t start, const int64_t nbytes,
std::shared_ptr<Buffer>* out) const;
/// Zero bytes in padding, i.e. bytes between size_ and capacity_.
void ZeroPadding() {
#ifndef NDEBUG
CheckMutable();
#endif
// A zero-capacity buffer can have a null data pointer
if (capacity_ != 0) {
memset(mutable_data_ + size_, 0, static_cast<size_t>(capacity_ - size_));
}
}
/// \brief Construct a new buffer that owns its memory from a std::string
///
/// \param[in] data a std::string object
/// \param[in] pool a memory pool
/// \param[out] out the created buffer
///
/// \return Status message
static Status FromString(const std::string& data, MemoryPool* pool,
std::shared_ptr<Buffer>* out);
/// \brief Construct a new buffer that owns its memory from a std::string
/// using the default memory pool
static Status FromString(const std::string& data, std::shared_ptr<Buffer>* out);
/// \brief Construct an immutable buffer that takes ownership of the contents
/// of an std::string
/// \param[in] data an rvalue-reference of a string
/// \return a new Buffer instance
static std::shared_ptr<Buffer> FromString(std::string&& data);
/// \brief Create buffer referencing typed memory with some length without
/// copying
/// \param[in] data the typed memory as C array
/// \param[in] length the number of values in the array
/// \return a new shared_ptr<Buffer>
template <typename T, typename SizeType = int64_t>
static std::shared_ptr<Buffer> Wrap(const T* data, SizeType length) {
return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data),
static_cast<int64_t>(sizeof(T) * length));
}
/// \brief Create buffer referencing std::vector with some length without
/// copying
/// \param[in] data the vector to be referenced. If this vector is changed,
/// the buffer may become invalid
/// \return a new shared_ptr<Buffer>
template <typename T>
static std::shared_ptr<Buffer> Wrap(const std::vector<T>& data) {
return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data.data()),
static_cast<int64_t>(sizeof(T) * data.size()));
}
/// \brief Copy buffer contents into a new std::string
/// \return std::string
/// \note Can throw std::bad_alloc if buffer is large
std::string ToString() const;
/// \brief View buffer contents as a util::string_view
/// \return util::string_view
explicit operator util::string_view() const {
return util::string_view(reinterpret_cast<const char*>(data_), size_);
}
/// \brief Return a pointer to the buffer's data
const uint8_t* data() const { return data_; }
/// \brief Return a writable pointer to the buffer's data
///
/// The buffer has to be mutable. Otherwise, an assertion may be thrown
/// or a null pointer may be returned.
uint8_t* mutable_data() {
#ifndef NDEBUG
CheckMutable();
#endif
return mutable_data_;
}
/// \brief Return the buffer's size in bytes
int64_t size() const { return size_; }
/// \brief Return the buffer's capacity (number of allocated bytes)
int64_t capacity() const { return capacity_; }
std::shared_ptr<Buffer> parent() const { return parent_; }
protected:
bool is_mutable_;
const uint8_t* data_;
uint8_t* mutable_data_;
int64_t size_;
int64_t capacity_;
// null by default, but may be set
std::shared_ptr<Buffer> parent_;
void CheckMutable() const;
private:
ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer);
};
using BufferVector = std::vector<std::shared_ptr<Buffer>>;
/// \defgroup buffer-slicing-functions Functions for slicing buffers
///
/// @{
/// \brief Construct a view on a buffer at the given offset and length.
///
/// This function cannot fail and does not check for errors (except in debug builds)
static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
const int64_t offset,
const int64_t length) {
return std::make_shared<Buffer>(buffer, offset, length);
}
/// \brief Construct a view on a buffer at the given offset, up to the buffer's end.
///
/// This function cannot fail and does not check for errors (except in debug builds)
static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
const int64_t offset) {
int64_t length = buffer->size() - offset;
return SliceBuffer(buffer, offset, length);
}
/// \brief Like SliceBuffer, but construct a mutable buffer slice.
///
/// If the parent buffer is not mutable, behavior is undefined (it may abort
/// in debug builds).
ARROW_EXPORT
std::shared_ptr<Buffer> SliceMutableBuffer(const std::shared_ptr<Buffer>& buffer,
const int64_t offset, const int64_t length);
/// \brief Like SliceBuffer, but construct a mutable buffer slice.
///
/// If the parent buffer is not mutable, behavior is undefined (it may abort
/// in debug builds).
static inline std::shared_ptr<Buffer> SliceMutableBuffer(
const std::shared_ptr<Buffer>& buffer, const int64_t offset) {
int64_t length = buffer->size() - offset;
return SliceMutableBuffer(buffer, offset, length);
}
/// @}
/// \class MutableBuffer
/// \brief A Buffer whose contents can be mutated. May or may not own its data.
class ARROW_EXPORT MutableBuffer : public Buffer {
public:
MutableBuffer(uint8_t* data, const int64_t size) : Buffer(data, size) {
mutable_data_ = data;
is_mutable_ = true;
}
MutableBuffer(const std::shared_ptr<Buffer>& parent, const int64_t offset,
const int64_t size);
/// \brief Create buffer referencing typed memory with some length
/// \param[in] data the typed memory as C array
/// \param[in] length the number of values in the array
/// \return a new shared_ptr<Buffer>
template <typename T, typename SizeType = int64_t>
static std::shared_ptr<Buffer> Wrap(T* data, SizeType length) {
return std::make_shared<MutableBuffer>(reinterpret_cast<uint8_t*>(data),
static_cast<int64_t>(sizeof(T) * length));
}
protected:
MutableBuffer() : Buffer(NULLPTR, 0) {}
};
/// \class ResizableBuffer
/// \brief A mutable buffer that can be resized
class ARROW_EXPORT ResizableBuffer : public MutableBuffer {
public:
/// Change buffer reported size to indicated size, allocating memory if
/// necessary. This will ensure that the capacity of the buffer is a multiple
/// of 64 bytes as defined in Layout.md.
/// Consider using ZeroPadding afterwards, to conform to the Arrow layout
/// specification.
///
/// @param new_size The new size for the buffer.
/// @param shrink_to_fit Whether to shrink the capacity if new size < current size
virtual Status Resize(const int64_t new_size, bool shrink_to_fit = true) = 0;
/// Ensure that buffer has enough memory allocated to fit the indicated
/// capacity (and meets the 64 byte padding requirement in Layout.md).
/// It does not change buffer's reported size and doesn't zero the padding.
virtual Status Reserve(const int64_t new_capacity) = 0;
template <class T>
Status TypedResize(const int64_t new_nb_elements, bool shrink_to_fit = true) {
return Resize(sizeof(T) * new_nb_elements, shrink_to_fit);
}
template <class T>
Status TypedReserve(const int64_t new_nb_elements) {
return Reserve(sizeof(T) * new_nb_elements);
}
protected:
ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {}
};
/// \defgroup buffer-allocation-functions Functions for allocating buffers
///
/// @{
/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding.
///
/// \param[in] pool a memory pool
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer (contains padding)
///
/// \return Status message
ARROW_EXPORT
Status AllocateBuffer(MemoryPool* pool, const int64_t size, std::shared_ptr<Buffer>* out);
/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding.
///
/// \param[in] pool a memory pool
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer (contains padding)
///
/// \return Status message
ARROW_EXPORT
Status AllocateBuffer(MemoryPool* pool, const int64_t size, std::unique_ptr<Buffer>* out);
/// \brief Allocate a fixed-size mutable buffer from the default memory pool
///
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer (contains padding)
///
/// \return Status message
ARROW_EXPORT
Status AllocateBuffer(const int64_t size, std::shared_ptr<Buffer>* out);
/// \brief Allocate a fixed-size mutable buffer from the default memory pool
///
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer (contains padding)
///
/// \return Status message
ARROW_EXPORT
Status AllocateBuffer(const int64_t size, std::unique_ptr<Buffer>* out);
/// \brief Allocate a resizeable buffer from a memory pool, zero its padding.
///
/// \param[in] pool a memory pool
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateResizableBuffer(MemoryPool* pool, const int64_t size,
std::shared_ptr<ResizableBuffer>* out);
/// \brief Allocate a resizeable buffer from a memory pool, zero its padding.
///
/// \param[in] pool a memory pool
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateResizableBuffer(MemoryPool* pool, const int64_t size,
std::unique_ptr<ResizableBuffer>* out);
/// \brief Allocate a resizeable buffer from the default memory pool
///
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateResizableBuffer(const int64_t size, std::shared_ptr<ResizableBuffer>* out);
/// \brief Allocate a resizeable buffer from the default memory pool
///
/// \param[in] size size of buffer to allocate
/// \param[out] out the allocated buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateResizableBuffer(const int64_t size, std::unique_ptr<ResizableBuffer>* out);
/// \brief Allocate a bitmap buffer from a memory pool
/// no guarantee on values is provided.
///
/// \param[in] pool memory pool to allocate memory from
/// \param[in] length size in bits of bitmap to allocate
/// \param[out] out the resulting buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateBitmap(MemoryPool* pool, int64_t length, std::shared_ptr<Buffer>* out);
/// \brief Allocate a zero-initialized bitmap buffer from a memory pool
///
/// \param[in] pool memory pool to allocate memory from
/// \param[in] length size in bits of bitmap to allocate
/// \param[out] out the resulting buffer (zero-initialized).
///
/// \return Status message
ARROW_EXPORT
Status AllocateEmptyBitmap(MemoryPool* pool, int64_t length,
std::shared_ptr<Buffer>* out);
/// \brief Allocate a zero-initialized bitmap buffer from the default memory pool
///
/// \param[in] length size in bits of bitmap to allocate
/// \param[out] out the resulting buffer
///
/// \return Status message
ARROW_EXPORT
Status AllocateEmptyBitmap(int64_t length, std::shared_ptr<Buffer>* out);
/// \brief Concatenate multiple buffers into a single buffer
///
/// \param[in] buffers to be concatenated
/// \param[in] pool memory pool to allocate the new buffer from
/// \param[out] out the concatenated buffer
///
/// \return Status
ARROW_EXPORT
Status ConcatenateBuffers(const BufferVector& buffers, MemoryPool* pool,
std::shared_ptr<Buffer>* out);
/// @}
} // namespace arrow
#endif // ARROW_BUFFER_H

View File

@ -1,58 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/array/builder_adaptive.h" // IWYU pragma: export
#include "arrow/array/builder_base.h" // IWYU pragma: export
#include "arrow/array/builder_binary.h" // IWYU pragma: export
#include "arrow/array/builder_decimal.h" // IWYU pragma: export
#include "arrow/array/builder_dict.h" // IWYU pragma: export
#include "arrow/array/builder_nested.h" // IWYU pragma: export
#include "arrow/array/builder_primitive.h" // IWYU pragma: export
#include "arrow/array/builder_time.h" // IWYU pragma: export
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class DataType;
class MemoryPool;
/// \brief Construct an empty ArrayBuilder corresponding to the data
/// type
/// \param[in] pool the MemoryPool to use for allocations
/// \param[in] type an instance of DictionaryType
/// \param[out] out the created ArrayBuilder
ARROW_EXPORT
Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);
/// \brief Construct an empty DictionaryBuilder initialized optionally
/// with a pre-existing dictionary
/// \param[in] pool the MemoryPool to use for allocations
/// \param[in] type an instance of DictionaryType
/// \param[in] dictionary the initial dictionary, if any. May be nullptr
/// \param[out] out the created ArrayBuilder
ARROW_EXPORT
Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
const std::shared_ptr<Array>& dictionary,
std::unique_ptr<ArrayBuilder>* out);
} // namespace arrow

View File

@ -1,101 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Functions for comparing Arrow data structures
#ifndef ARROW_COMPARE_H
#define ARROW_COMPARE_H
#include <cstdint>
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
class Tensor;
class SparseTensor;
struct Scalar;
static constexpr double kDefaultAbsoluteTolerance = 1E-5;
/// A container of options for equality comparisons
class EqualOptions {
public:
/// Whether or not NaNs are considered equal.
bool nans_equal() const { return nans_equal_; }
/// Return a new EqualOptions object with the "nans_equal" property changed.
EqualOptions nans_equal(bool v) const {
auto res = EqualOptions(*this);
res.nans_equal_ = v;
return res;
}
/// The absolute tolerance for approximate comparisons of floating-point values.
double atol() const { return atol_; }
/// Return a new EqualOptions object with the "atol" property changed.
EqualOptions atol(double v) const {
auto res = EqualOptions(*this);
res.atol_ = v;
return res;
}
static EqualOptions Defaults() { return EqualOptions(); }
protected:
double atol_ = kDefaultAbsoluteTolerance;
bool nans_equal_ = false;
};
/// Returns true if the arrays are exactly equal
bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right,
const EqualOptions& = EqualOptions::Defaults());
bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right);
/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal
bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left, const SparseTensor& right);
/// Returns true if the arrays are approximately equal. For non-floating point
/// types, this is equivalent to ArrayEquals(left, right)
bool ARROW_EXPORT ArrayApproxEquals(const Array& left, const Array& right,
const EqualOptions& = EqualOptions::Defaults());
/// Returns true if indicated equal-length segment of arrays is exactly equal
bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
int64_t start_idx, int64_t end_idx,
int64_t other_start_idx);
/// Returns true if the type metadata are exactly equal
/// \param[in] left a DataType
/// \param[in] right a DataType
/// \param[in] check_metadata whether to compare KeyValueMetadata for child
/// fields
bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
bool check_metadata = true);
/// Returns true if scalars are equal
/// \param[in] left a Scalar
/// \param[in] right a Scalar
bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right);
} // namespace arrow
#endif // ARROW_COMPARE_H

View File

@ -1,33 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_API_H
#define ARROW_COMPUTE_API_H
#include "arrow/compute/context.h" // IWYU pragma: export
#include "arrow/compute/kernel.h" // IWYU pragma: export
#include "arrow/compute/kernels/boolean.h" // IWYU pragma: export
#include "arrow/compute/kernels/cast.h" // IWYU pragma: export
#include "arrow/compute/kernels/compare.h" // IWYU pragma: export
#include "arrow/compute/kernels/count.h" // IWYU pragma: export
#include "arrow/compute/kernels/hash.h" // IWYU pragma: export
#include "arrow/compute/kernels/mean.h" // IWYU pragma: export
#include "arrow/compute/kernels/sum.h" // IWYU pragma: export
#include "arrow/compute/kernels/take.h" // IWYU pragma: export
#endif // ARROW_COMPUTE_API_H

View File

@ -1,97 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <vector>
#include "arrow/testing/gtest_util.h"
#include "arrow/util/cpu-info.h"
namespace arrow {
namespace compute {
using internal::CpuInfo;
static CpuInfo* cpu_info = CpuInfo::GetInstance();
static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE);
static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE);
static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE);
static const int64_t kCantFitInL3Size = kL3Size * 4;
static const std::vector<int64_t> kMemorySizes = {kL1Size, kL2Size, kL3Size,
kCantFitInL3Size};
template <typename Func>
struct BenchmarkArgsType;
// Pattern matching that extracts the vector element type of Benchmark::Args()
template <typename Values>
struct BenchmarkArgsType<benchmark::internal::Benchmark* (
benchmark::internal::Benchmark::*)(const std::vector<Values>&)> {
using type = Values;
};
// Benchmark changed its parameter type between releases from
// int to int64_t. As it doesn't have version macros, we need
// to apply C++ template magic.
using ArgsType =
typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type;
void BenchmarkSetArgsWithSizes(benchmark::internal::Benchmark* bench,
const std::vector<int64_t>& sizes = kMemorySizes) {
bench->Unit(benchmark::kMicrosecond);
for (auto size : sizes)
for (auto nulls : std::vector<ArgsType>({0, 1, 10, 50}))
bench->Args({static_cast<ArgsType>(size), nulls});
}
void BenchmarkSetArgs(benchmark::internal::Benchmark* bench) {
BenchmarkSetArgsWithSizes(bench, kMemorySizes);
}
void RegressionSetArgs(benchmark::internal::Benchmark* bench) {
// Regression do not need to account for cache hierarchy, thus optimize for
// the best case.
BenchmarkSetArgsWithSizes(bench, {kL1Size});
}
// RAII struct to handle some of the boilerplate in regression benchmarks
struct RegressionArgs {
// size of memory tested (per iteration) in bytes
const int64_t size;
// proportion of nulls in generated arrays
const double null_proportion;
explicit RegressionArgs(benchmark::State& state)
: size(state.range(0)),
null_proportion(static_cast<double>(state.range(1)) / 100.0),
state_(state) {}
~RegressionArgs() {
state_.counters["size"] = static_cast<double>(size);
state_.counters["null_percent"] = static_cast<double>(state_.range(1));
state_.SetBytesProcessed(state_.iterations() * size);
}
private:
benchmark::State& state_;
};
} // namespace compute
} // namespace arrow

View File

@ -1,82 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_CONTEXT_H
#define ARROW_COMPUTE_CONTEXT_H
#include <cstdint>
#include <memory>
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Buffer;
namespace internal {
class CpuInfo;
} // namespace internal
namespace compute {
#define RETURN_IF_ERROR(ctx) \
if (ARROW_PREDICT_FALSE(ctx->HasError())) { \
Status s = ctx->status(); \
ctx->ResetStatus(); \
return s; \
}
/// \brief Container for variables and options used by function evaluation
class ARROW_EXPORT FunctionContext {
public:
explicit FunctionContext(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
MemoryPool* memory_pool() const;
/// \brief Allocate buffer from the context's memory pool
Status Allocate(const int64_t nbytes, std::shared_ptr<Buffer>* out);
/// \brief Indicate that an error has occurred, to be checked by a parent caller
/// \param[in] status a Status instance
///
/// \note Will not overwrite a prior set Status, so we will have the first
/// error that occurred until FunctionContext::ResetStatus is called
void SetStatus(const Status& status);
/// \brief Clear any error status
void ResetStatus();
/// \brief Return true if an error has occurred
bool HasError() const { return !status_.ok(); }
/// \brief Return the current status of the context
const Status& status() const { return status_; }
internal::CpuInfo* cpu_info() const { return cpu_info_; }
private:
Status status_;
MemoryPool* pool_;
internal::CpuInfo* cpu_info_;
};
} // namespace compute
} // namespace arrow
#endif // ARROW_COMPUTE_CONTEXT_H

View File

@ -1,261 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <string>
#include "arrow/compute/type_fwd.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
namespace arrow {
namespace compute {
class LogicalType;
class ExprVisitor;
class Operation;
/// \brief Base class for all analytic expressions. Expressions may represent
/// data values (scalars, arrays, tables)
class ARROW_EXPORT Expr {
public:
/// \brief Instantiate expression from an abstract operation
/// \param[in] op the operation that generates the expression
explicit Expr(ConstOpPtr op);
virtual ~Expr() = default;
/// \brief A unique string identifier for the kind of expression
virtual std::string kind() const = 0;
/// \brief Accept expression visitor
/// TODO(wesm)
// virtual Status Accept(ExprVisitor* visitor) const = 0;
/// \brief The underlying operation
ConstOpPtr op() const { return op_; }
protected:
ConstOpPtr op_;
};
/// The value cardinality: one or many. These correspond to the arrow::Scalar
/// and arrow::Array types
enum class ValueRank { SCALAR, ARRAY };
/// \brief Base class for a data-generated expression with a fixed and known
/// type. This includes arrays and scalars
class ARROW_EXPORT ValueExpr : public Expr {
public:
/// \brief The name of the expression, if any. The default is unnamed
// virtual const ExprName& name() const;
LogicalTypePtr type() const;
/// \brief The value cardinality (scalar or array) of the expression
virtual ValueRank rank() const = 0;
protected:
ValueExpr(ConstOpPtr op, LogicalTypePtr type);
/// \brief The semantic data type of the expression
LogicalTypePtr type_;
};
class ARROW_EXPORT ArrayExpr : public ValueExpr {
protected:
using ValueExpr::ValueExpr;
std::string kind() const override;
ValueRank rank() const override;
};
class ARROW_EXPORT ScalarExpr : public ValueExpr {
protected:
using ValueExpr::ValueExpr;
std::string kind() const override;
ValueRank rank() const override;
};
namespace value {
// These are mixin classes to provide a type hierarchy for values identify
class ValueMixin {};
class Null : public ValueMixin {};
class Bool : public ValueMixin {};
class Number : public ValueMixin {};
class Integer : public Number {};
class SignedInteger : public Integer {};
class Int8 : public SignedInteger {};
class Int16 : public SignedInteger {};
class Int32 : public SignedInteger {};
class Int64 : public SignedInteger {};
class UnsignedInteger : public Integer {};
class UInt8 : public UnsignedInteger {};
class UInt16 : public UnsignedInteger {};
class UInt32 : public UnsignedInteger {};
class UInt64 : public UnsignedInteger {};
class Floating : public Number {};
class Float16 : public Floating {};
class Float32 : public Floating {};
class Float64 : public Floating {};
class Binary : public ValueMixin {};
class Utf8 : public Binary {};
class List : public ValueMixin {};
class Struct : public ValueMixin {};
} // namespace value
#define SIMPLE_EXPR_FACTORY(NAME) ARROW_EXPORT ExprPtr NAME(ConstOpPtr op);
namespace scalar {
#define DECLARE_SCALAR_EXPR(TYPE) \
class ARROW_EXPORT TYPE : public ScalarExpr, public value::TYPE { \
public: \
explicit TYPE(ConstOpPtr op); \
using ScalarExpr::kind; \
};
DECLARE_SCALAR_EXPR(Null)
DECLARE_SCALAR_EXPR(Bool)
DECLARE_SCALAR_EXPR(Int8)
DECLARE_SCALAR_EXPR(Int16)
DECLARE_SCALAR_EXPR(Int32)
DECLARE_SCALAR_EXPR(Int64)
DECLARE_SCALAR_EXPR(UInt8)
DECLARE_SCALAR_EXPR(UInt16)
DECLARE_SCALAR_EXPR(UInt32)
DECLARE_SCALAR_EXPR(UInt64)
DECLARE_SCALAR_EXPR(Float16)
DECLARE_SCALAR_EXPR(Float32)
DECLARE_SCALAR_EXPR(Float64)
DECLARE_SCALAR_EXPR(Binary)
DECLARE_SCALAR_EXPR(Utf8)
#undef DECLARE_SCALAR_EXPR
SIMPLE_EXPR_FACTORY(null);
SIMPLE_EXPR_FACTORY(boolean);
SIMPLE_EXPR_FACTORY(int8);
SIMPLE_EXPR_FACTORY(int16);
SIMPLE_EXPR_FACTORY(int32);
SIMPLE_EXPR_FACTORY(int64);
SIMPLE_EXPR_FACTORY(uint8);
SIMPLE_EXPR_FACTORY(uint16);
SIMPLE_EXPR_FACTORY(uint32);
SIMPLE_EXPR_FACTORY(uint64);
SIMPLE_EXPR_FACTORY(float16);
SIMPLE_EXPR_FACTORY(float32);
SIMPLE_EXPR_FACTORY(float64);
SIMPLE_EXPR_FACTORY(binary);
SIMPLE_EXPR_FACTORY(utf8);
class ARROW_EXPORT List : public ScalarExpr, public value::List {
public:
List(ConstOpPtr op, LogicalTypePtr type);
using ScalarExpr::kind;
};
class ARROW_EXPORT Struct : public ScalarExpr, public value::Struct {
public:
Struct(ConstOpPtr op, LogicalTypePtr type);
using ScalarExpr::kind;
};
} // namespace scalar
namespace array {
#define DECLARE_ARRAY_EXPR(TYPE) \
class ARROW_EXPORT TYPE : public ArrayExpr, public value::TYPE { \
public: \
explicit TYPE(ConstOpPtr op); \
using ArrayExpr::kind; \
};
DECLARE_ARRAY_EXPR(Null)
DECLARE_ARRAY_EXPR(Bool)
DECLARE_ARRAY_EXPR(Int8)
DECLARE_ARRAY_EXPR(Int16)
DECLARE_ARRAY_EXPR(Int32)
DECLARE_ARRAY_EXPR(Int64)
DECLARE_ARRAY_EXPR(UInt8)
DECLARE_ARRAY_EXPR(UInt16)
DECLARE_ARRAY_EXPR(UInt32)
DECLARE_ARRAY_EXPR(UInt64)
DECLARE_ARRAY_EXPR(Float16)
DECLARE_ARRAY_EXPR(Float32)
DECLARE_ARRAY_EXPR(Float64)
DECLARE_ARRAY_EXPR(Binary)
DECLARE_ARRAY_EXPR(Utf8)
#undef DECLARE_ARRAY_EXPR
SIMPLE_EXPR_FACTORY(null);
SIMPLE_EXPR_FACTORY(boolean);
SIMPLE_EXPR_FACTORY(int8);
SIMPLE_EXPR_FACTORY(int16);
SIMPLE_EXPR_FACTORY(int32);
SIMPLE_EXPR_FACTORY(int64);
SIMPLE_EXPR_FACTORY(uint8);
SIMPLE_EXPR_FACTORY(uint16);
SIMPLE_EXPR_FACTORY(uint32);
SIMPLE_EXPR_FACTORY(uint64);
SIMPLE_EXPR_FACTORY(float16);
SIMPLE_EXPR_FACTORY(float32);
SIMPLE_EXPR_FACTORY(float64);
SIMPLE_EXPR_FACTORY(binary);
SIMPLE_EXPR_FACTORY(utf8);
class ARROW_EXPORT List : public ArrayExpr, public value::List {
public:
List(ConstOpPtr op, LogicalTypePtr type);
using ArrayExpr::kind;
};
class ARROW_EXPORT Struct : public ArrayExpr, public value::Struct {
public:
Struct(ConstOpPtr op, LogicalTypePtr type);
using ArrayExpr::kind;
};
} // namespace array
#undef SIMPLE_EXPR_FACTORY
template <typename T, typename ObjectType>
inline bool InheritsFrom(const ObjectType* obj) {
return dynamic_cast<const T*>(obj) != NULLPTR;
}
template <typename T, typename ObjectType>
inline bool InheritsFrom(const ObjectType& obj) {
return dynamic_cast<const T*>(&obj) != NULLPTR;
}
/// \brief Construct a ScalarExpr containing an Operation given a logical type
ARROW_EXPORT
Status GetScalarExpr(ConstOpPtr op, LogicalTypePtr ty, ExprPtr* out);
/// \brief Construct an ArrayExpr containing an Operation given a logical type
ARROW_EXPORT
Status GetArrayExpr(ConstOpPtr op, LogicalTypePtr ty, ExprPtr* out);
} // namespace compute
} // namespace arrow

View File

@ -1,289 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_KERNEL_H
#define ARROW_COMPUTE_KERNEL_H
#include <memory>
#include <utility>
#include <vector>
#include "arrow/array.h"
#include "arrow/record_batch.h"
#include "arrow/scalar.h"
#include "arrow/table.h"
#include "arrow/util/macros.h"
#include "arrow/util/memory.h"
#include "arrow/util/variant.h" // IWYU pragma: export
#include "arrow/util/visibility.h"
namespace arrow {
namespace compute {
class FunctionContext;
/// \class OpKernel
/// \brief Base class for operator kernels
///
/// Note to implementors:
/// Operator kernels are intended to be the lowest level of an analytics/compute
/// engine. They will generally not be exposed directly to end-users. Instead
/// they will be wrapped by higher level constructs (e.g. top-level functions
/// or physical execution plan nodes). These higher level constructs are
/// responsible for user input validation and returning the appropriate
/// error Status.
///
/// Due to this design, implementations of Call (the execution
/// method on subclasses) should use assertions (i.e. DCHECK) to double-check
/// parameter arguments when in higher level components returning an
/// InvalidArgument error might be more appropriate.
///
class ARROW_EXPORT OpKernel {
public:
virtual ~OpKernel() = default;
/// \brief EXPERIMENTAL The output data type of the kernel
/// \return the output type
virtual std::shared_ptr<DataType> out_type() const = 0;
};
struct Datum;
static inline bool CollectionEquals(const std::vector<Datum>& left,
const std::vector<Datum>& right);
// Datums variants may have a length. This special value indicate that the
// current variant does not have a length.
constexpr int64_t kUnknownLength = -1;
/// \class Datum
/// \brief Variant type for various Arrow C++ data structures
struct ARROW_EXPORT Datum {
enum type { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION };
util::variant<decltype(NULLPTR), std::shared_ptr<Scalar>, std::shared_ptr<ArrayData>,
std::shared_ptr<ChunkedArray>, std::shared_ptr<RecordBatch>,
std::shared_ptr<Table>, std::vector<Datum>>
value;
/// \brief Empty datum, to be populated elsewhere
Datum() : value(NULLPTR) {}
Datum(const std::shared_ptr<Scalar>& value) // NOLINT implicit conversion
: value(value) {}
Datum(const std::shared_ptr<ArrayData>& value) // NOLINT implicit conversion
: value(value) {}
Datum(const std::shared_ptr<Array>& value) // NOLINT implicit conversion
: Datum(value ? value->data() : NULLPTR) {}
Datum(const std::shared_ptr<ChunkedArray>& value) // NOLINT implicit conversion
: value(value) {}
Datum(const std::shared_ptr<RecordBatch>& value) // NOLINT implicit conversion
: value(value) {}
Datum(const std::shared_ptr<Table>& value) // NOLINT implicit conversion
: value(value) {}
Datum(const std::vector<Datum>& value) // NOLINT implicit conversion
: value(value) {}
// Cast from subtypes of Array to Datum
template <typename T,
typename = typename std::enable_if<std::is_base_of<Array, T>::value>::type>
Datum(const std::shared_ptr<T>& value) // NOLINT implicit conversion
: Datum(std::shared_ptr<Array>(value)) {}
// Convenience constructors
explicit Datum(bool value) : value(std::make_shared<BooleanScalar>(value)) {}
explicit Datum(int8_t value) : value(std::make_shared<Int8Scalar>(value)) {}
explicit Datum(uint8_t value) : value(std::make_shared<UInt8Scalar>(value)) {}
explicit Datum(int16_t value) : value(std::make_shared<Int16Scalar>(value)) {}
explicit Datum(uint16_t value) : value(std::make_shared<UInt16Scalar>(value)) {}
explicit Datum(int32_t value) : value(std::make_shared<Int32Scalar>(value)) {}
explicit Datum(uint32_t value) : value(std::make_shared<UInt32Scalar>(value)) {}
explicit Datum(int64_t value) : value(std::make_shared<Int64Scalar>(value)) {}
explicit Datum(uint64_t value) : value(std::make_shared<UInt64Scalar>(value)) {}
explicit Datum(float value) : value(std::make_shared<FloatScalar>(value)) {}
explicit Datum(double value) : value(std::make_shared<DoubleScalar>(value)) {}
~Datum() {}
Datum(const Datum& other) noexcept { this->value = other.value; }
Datum& operator=(const Datum& other) noexcept {
value = other.value;
return *this;
}
// Define move constructor and move assignment, for better performance
Datum(Datum&& other) noexcept : value(std::move(other.value)) {}
Datum& operator=(Datum&& other) noexcept {
value = std::move(other.value);
return *this;
}
Datum::type kind() const {
switch (this->value.index()) {
case 0:
return Datum::NONE;
case 1:
return Datum::SCALAR;
case 2:
return Datum::ARRAY;
case 3:
return Datum::CHUNKED_ARRAY;
case 4:
return Datum::RECORD_BATCH;
case 5:
return Datum::TABLE;
case 6:
return Datum::COLLECTION;
default:
return Datum::NONE;
}
}
std::shared_ptr<ArrayData> array() const {
return util::get<std::shared_ptr<ArrayData>>(this->value);
}
std::shared_ptr<Array> make_array() const {
return MakeArray(util::get<std::shared_ptr<ArrayData>>(this->value));
}
std::shared_ptr<ChunkedArray> chunked_array() const {
return util::get<std::shared_ptr<ChunkedArray>>(this->value);
}
std::shared_ptr<RecordBatch> record_batch() const {
return util::get<std::shared_ptr<RecordBatch>>(this->value);
}
std::shared_ptr<Table> table() const {
return util::get<std::shared_ptr<Table>>(this->value);
}
const std::vector<Datum> collection() const {
return util::get<std::vector<Datum>>(this->value);
}
std::shared_ptr<Scalar> scalar() const {
return util::get<std::shared_ptr<Scalar>>(this->value);
}
bool is_array() const { return this->kind() == Datum::ARRAY; }
bool is_arraylike() const {
return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY;
}
bool is_scalar() const { return this->kind() == Datum::SCALAR; }
/// \brief The value type of the variant, if any
///
/// \return nullptr if no type
std::shared_ptr<DataType> type() const {
if (this->kind() == Datum::ARRAY) {
return util::get<std::shared_ptr<ArrayData>>(this->value)->type;
} else if (this->kind() == Datum::CHUNKED_ARRAY) {
return util::get<std::shared_ptr<ChunkedArray>>(this->value)->type();
} else if (this->kind() == Datum::SCALAR) {
return util::get<std::shared_ptr<Scalar>>(this->value)->type;
}
return NULLPTR;
}
/// \brief The value length of the variant, if any
///
/// \return kUnknownLength if no type
int64_t length() const {
if (this->kind() == Datum::ARRAY) {
return util::get<std::shared_ptr<ArrayData>>(this->value)->length;
} else if (this->kind() == Datum::CHUNKED_ARRAY) {
return util::get<std::shared_ptr<ChunkedArray>>(this->value)->length();
} else if (this->kind() == Datum::SCALAR) {
return 1;
}
return kUnknownLength;
}
bool Equals(const Datum& other) const {
if (this->kind() != other.kind()) return false;
switch (this->kind()) {
case Datum::NONE:
return true;
case Datum::SCALAR:
return internal::SharedPtrEquals(this->scalar(), other.scalar());
case Datum::ARRAY:
return internal::SharedPtrEquals(this->make_array(), other.make_array());
case Datum::CHUNKED_ARRAY:
return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array());
case Datum::RECORD_BATCH:
return internal::SharedPtrEquals(this->record_batch(), other.record_batch());
case Datum::TABLE:
return internal::SharedPtrEquals(this->table(), other.table());
case Datum::COLLECTION:
return CollectionEquals(this->collection(), other.collection());
default:
return false;
}
}
};
/// \class UnaryKernel
/// \brief An array-valued function of a single input argument.
///
/// Note to implementors: Try to avoid making kernels that allocate memory if
/// the output size is a deterministic function of the Input Datum's metadata.
/// Instead separate the logic of the kernel and allocations necessary into
/// two different kernels. Some reusable kernels that allocate buffers
/// and delegate computation to another kernel are available in util-internal.h.
class ARROW_EXPORT UnaryKernel : public OpKernel {
public:
/// \brief Executes the kernel.
///
/// \param[in] ctx The function context for the kernel
/// \param[in] input The kernel input data
/// \param[out] out The output of the function. Each implementation of this
/// function might assume different things about the existing contents of out
/// (e.g. which buffers are preallocated). In the future it is expected that
/// there will be a more generic mechansim for understanding the necessary
/// contracts.
virtual Status Call(FunctionContext* ctx, const Datum& input, Datum* out) = 0;
};
/// \class BinaryKernel
/// \brief An array-valued function of a two input arguments
class ARROW_EXPORT BinaryKernel : public OpKernel {
public:
virtual Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
Datum* out) = 0;
};
static inline bool CollectionEquals(const std::vector<Datum>& left,
const std::vector<Datum>& right) {
if (left.size() != right.size()) return false;
for (size_t i = 0; i < left.size(); i++)
if (!left[i].Equals(right[i])) return false;
return true;
}
} // namespace compute
} // namespace arrow
#endif // ARROW_COMPUTE_KERNEL_H

View File

@ -1,115 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/compute/kernel.h"
namespace arrow {
class Array;
class Status;
namespace compute {
class FunctionContext;
struct Datum;
/// AggregateFunction is an interface for Aggregates
///
/// An aggregates transforms an array into single result called a state via the
/// Consume method.. State supports the merge operation via the Merge method.
/// State can be sealed into a final result via the Finalize method.
//
/// State ownership is handled by callers, thus the interface exposes 3 methods
/// for the caller to manage memory:
/// - Size
/// - New (placement new constructor invocation)
/// - Delete (state desctructor)
///
/// Design inspired by ClickHouse aggregate functions.
class AggregateFunction {
public:
/// \brief Consume an array into a state.
virtual Status Consume(const Array& input, void* state) const = 0;
/// \brief Merge states.
virtual Status Merge(const void* src, void* dst) const = 0;
/// \brief Convert state into a final result.
virtual Status Finalize(const void* src, Datum* output) const = 0;
virtual ~AggregateFunction() {}
virtual std::shared_ptr<DataType> out_type() const = 0;
/// State management methods.
virtual int64_t Size() const = 0;
virtual void New(void* ptr) const = 0;
virtual void Delete(void* ptr) const = 0;
};
/// AggregateFunction partial implementation for static type state
template <typename State>
class AggregateFunctionStaticState : public AggregateFunction {
virtual Status Consume(const Array& input, State* state) const = 0;
virtual Status Merge(const State& src, State* dst) const = 0;
virtual Status Finalize(const State& src, Datum* output) const = 0;
Status Consume(const Array& input, void* state) const final {
return Consume(input, static_cast<State*>(state));
}
Status Merge(const void* src, void* dst) const final {
return Merge(*static_cast<const State*>(src), static_cast<State*>(dst));
}
/// \brief Convert state into a final result.
Status Finalize(const void* src, Datum* output) const final {
return Finalize(*static_cast<const State*>(src), output);
}
int64_t Size() const final { return sizeof(State); }
void New(void* ptr) const final {
// By using placement-new syntax, the constructor of the State is invoked
// in the memory location defined by the caller. This only supports State
// with a parameter-less constructor.
new (ptr) State;
}
void Delete(void* ptr) const final { static_cast<State*>(ptr)->~State(); }
};
/// \brief UnaryKernel implemented by an AggregateState
class ARROW_EXPORT AggregateUnaryKernel : public UnaryKernel {
public:
explicit AggregateUnaryKernel(std::shared_ptr<AggregateFunction>& aggregate)
: aggregate_function_(aggregate) {}
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override;
std::shared_ptr<DataType> out_type() const override;
private:
std::shared_ptr<AggregateFunction> aggregate_function_;
};
} // namespace compute
} // namespace arrow

View File

@ -1,76 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_KERNELS_BOOLEAN_H
#define ARROW_COMPUTE_KERNELS_BOOLEAN_H
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
namespace compute {
struct Datum;
class FunctionContext;
/// \brief Invert the values of a boolean datum
/// \param[in] context the FunctionContext
/// \param[in] value datum to invert
/// \param[out] out resulting datum
///
/// \since 0.11.0
/// \note API not yet finalized
ARROW_EXPORT
Status Invert(FunctionContext* context, const Datum& value, Datum* out);
/// \brief Element-wise AND of two boolean datums
/// \param[in] context the FunctionContext
/// \param[in] left left operand (array)
/// \param[in] right right operand (array)
/// \param[out] out resulting datum
///
/// \since 0.11.0
/// \note API not yet finalized
ARROW_EXPORT
Status And(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
/// \brief Element-wise OR of two boolean datums
/// \param[in] context the FunctionContext
/// \param[in] left left operand (array)
/// \param[in] right right operand (array)
/// \param[out] out resulting datum
///
/// \since 0.11.0
/// \note API not yet finalized
ARROW_EXPORT
Status Or(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
/// \brief Element-wise XOR of two boolean datums
/// \param[in] context the FunctionContext
/// \param[in] left left operand (array)
/// \param[in] right right operand (array)
/// \param[out] out resulting datum
///
/// \since 0.11.0
/// \note API not yet finalized
ARROW_EXPORT
Status Xor(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
} // namespace compute
} // namespace arrow
#endif // ARROW_COMPUTE_KERNELS_CAST_H

View File

@ -1,98 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_KERNELS_CAST_H
#define ARROW_COMPUTE_KERNELS_CAST_H
#include <memory>
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
namespace compute {
struct Datum;
class FunctionContext;
class UnaryKernel;
struct ARROW_EXPORT CastOptions {
CastOptions()
: allow_int_overflow(false),
allow_time_truncate(false),
allow_float_truncate(false),
allow_invalid_utf8(false) {}
explicit CastOptions(bool safe)
: allow_int_overflow(!safe),
allow_time_truncate(!safe),
allow_float_truncate(!safe),
allow_invalid_utf8(!safe) {}
static CastOptions Safe() { return CastOptions(true); }
static CastOptions Unsafe() { return CastOptions(false); }
bool allow_int_overflow;
bool allow_time_truncate;
bool allow_float_truncate;
// Indicate if conversions from Binary/FixedSizeBinary to string must
// validate the utf8 payload.
bool allow_invalid_utf8;
};
/// \since 0.7.0
/// \note API not yet finalized
ARROW_EXPORT
Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> to_type,
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel);
/// \brief Cast from one array type to another
/// \param[in] context the FunctionContext
/// \param[in] value array to cast
/// \param[in] to_type type to cast to
/// \param[in] options casting options
/// \param[out] out resulting array
///
/// \since 0.7.0
/// \note API not yet finalized
ARROW_EXPORT
Status Cast(FunctionContext* context, const Array& value,
std::shared_ptr<DataType> to_type, const CastOptions& options,
std::shared_ptr<Array>* out);
/// \brief Cast from one value to another
/// \param[in] context the FunctionContext
/// \param[in] value datum to cast
/// \param[in] to_type type to cast to
/// \param[in] options casting options
/// \param[out] out resulting datum
///
/// \since 0.8.0
/// \note API not yet finalized
ARROW_EXPORT
Status Cast(FunctionContext* context, const Datum& value,
std::shared_ptr<DataType> to_type, const CastOptions& options, Datum* out);
} // namespace compute
} // namespace arrow
#endif // ARROW_COMPUTE_KERNELS_CAST_H

View File

@ -1,176 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/compute/kernel.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
struct Scalar;
class Status;
namespace compute {
struct Datum;
class FunctionContext;
/// CompareFunction is an interface for Comparisons
///
/// Comparisons take an array and emits a selection vector. The selection vector
/// is given in the form of a bitmask as a BooleanArray result.
class ARROW_EXPORT CompareFunction {
public:
/// Compare an array with a scalar argument.
virtual Status Compare(const ArrayData& array, const Scalar& scalar,
ArrayData* output) const = 0;
Status Compare(const ArrayData& array, const Scalar& scalar,
std::shared_ptr<ArrayData>* output) {
return Compare(array, scalar, output->get());
}
virtual Status Compare(const Scalar& scalar, const ArrayData& array,
ArrayData* output) const = 0;
Status Compare(const Scalar& scalar, const ArrayData& array,
std::shared_ptr<ArrayData>* output) {
return Compare(scalar, array, output->get());
}
/// Compare an array with an array argument.
virtual Status Compare(const ArrayData& lhs, const ArrayData& rhs,
ArrayData* output) const = 0;
Status Compare(const ArrayData& lhs, const ArrayData& rhs,
std::shared_ptr<ArrayData>* output) {
return Compare(lhs, rhs, output->get());
}
/// By default, CompareFunction emits a result bitmap.
virtual std::shared_ptr<DataType> out_type() const { return boolean(); }
virtual ~CompareFunction() {}
};
/// \brief BinaryKernel bound to a select function
class ARROW_EXPORT CompareBinaryKernel : public BinaryKernel {
public:
explicit CompareBinaryKernel(std::shared_ptr<CompareFunction>& select)
: compare_function_(select) {}
Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
Datum* out) override;
static int64_t out_length(const Datum& left, const Datum& right) {
if (left.kind() == Datum::ARRAY) return left.length();
if (right.kind() == Datum::ARRAY) return right.length();
return 0;
}
std::shared_ptr<DataType> out_type() const override;
private:
std::shared_ptr<CompareFunction> compare_function_;
};
enum CompareOperator {
EQUAL,
NOT_EQUAL,
GREATER,
GREATER_EQUAL,
LESS,
LESS_EQUAL,
};
template <typename T, CompareOperator Op>
struct Comparator;
template <typename T>
struct Comparator<T, CompareOperator::EQUAL> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs == rhs; }
};
template <typename T>
struct Comparator<T, CompareOperator::NOT_EQUAL> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs != rhs; }
};
template <typename T>
struct Comparator<T, CompareOperator::GREATER> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs > rhs; }
};
template <typename T>
struct Comparator<T, CompareOperator::GREATER_EQUAL> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs >= rhs; }
};
template <typename T>
struct Comparator<T, CompareOperator::LESS> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs < rhs; }
};
template <typename T>
struct Comparator<T, CompareOperator::LESS_EQUAL> {
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs <= rhs; }
};
struct CompareOptions {
explicit CompareOptions(CompareOperator op) : op(op) {}
enum CompareOperator op;
};
/// \brief Return a Compare CompareFunction
///
/// \param[in] context FunctionContext passing context information
/// \param[in] type required to specialize the kernel
/// \param[in] options required to specify the compare operator
///
/// \since 0.14.0
/// \note API not yet finalized
ARROW_EXPORT
std::shared_ptr<CompareFunction> MakeCompareFunction(FunctionContext* context,
const DataType& type,
struct CompareOptions options);
/// \brief Compare a numeric array with a scalar.
///
/// \param[in] context the FunctionContext
/// \param[in] left datum to compare, must be an Array
/// \param[in] right datum to compare, must be a Scalar of the same type than
/// left Datum.
/// \param[in] options compare options
/// \param[out] out resulting datum
///
/// Note on floating point arrays, this uses ieee-754 compare semantics.
///
/// \since 0.14.0
/// \note API not yet finalized
ARROW_EXPORT
Status Compare(FunctionContext* context, const Datum& left, const Datum& right,
struct CompareOptions options, Datum* out);
} // namespace compute
} // namespace arrow

View File

@ -1,88 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <type_traits>
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
namespace compute {
struct Datum;
class FunctionContext;
class AggregateFunction;
/// \class CountOptions
///
/// The user control the Count kernel behavior with this class. By default, the
/// it will count all non-null values.
struct ARROW_EXPORT CountOptions {
enum mode {
// Count all non-null values.
COUNT_ALL = 0,
// Count all null values.
COUNT_NULL,
};
explicit CountOptions(enum mode count_mode) : count_mode(count_mode) {}
enum mode count_mode = COUNT_ALL;
};
/// \brief Return Count function aggregate
ARROW_EXPORT
std::shared_ptr<AggregateFunction> MakeCount(FunctionContext* context,
const CountOptions& options);
/// \brief Count non-null (or null) values in an array.
///
/// \param[in] context the FunctionContext
/// \param[in] options counting options, see CountOptions for more information
/// \param[in] datum to count
/// \param[out] out resulting datum
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Count(FunctionContext* context, const CountOptions& options, const Datum& datum,
Datum* out);
/// \brief Count non-null (or null) values in an array.
///
/// \param[in] context the FunctionContext
/// \param[in] options counting options, see CountOptions for more information
/// \param[in] array to count
/// \param[out] out resulting datum
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Count(FunctionContext* context, const CountOptions& options, const Array& array,
Datum* out);
} // namespace compute
} // namespace arrow

View File

@ -1,93 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/compute/kernel.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
namespace compute {
class FunctionContext;
/// \brief Filter an array with a boolean selection filter
///
/// The output array will be populated with values from the input at positions
/// where the selection filter is not 0. Nulls in the filter will result in nulls
/// in the output.
///
/// For example given values = ["a", "b", "c", null, "e", "f"] and
/// filter = [0, 1, 1, 0, null, 1], the output will be
/// = ["b", "c", null, "f"]
///
/// \param[in] ctx the FunctionContext
/// \param[in] values array to filter
/// \param[in] filter indicates which values should be filtered out
/// \param[out] out resulting array
ARROW_EXPORT
Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
std::shared_ptr<Array>* out);
/// \brief Filter an array with a boolean selection filter
///
/// \param[in] ctx the FunctionContext
/// \param[in] values datum to filter
/// \param[in] filter indicates which values should be filtered out
/// \param[out] out resulting datum
ARROW_EXPORT
Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out);
/// \brief BinaryKernel implementing Filter operation
class ARROW_EXPORT FilterKernel : public BinaryKernel {
public:
explicit FilterKernel(const std::shared_ptr<DataType>& type) : type_(type) {}
/// \brief BinaryKernel interface
///
/// delegates to subclasses via Filter()
Status Call(FunctionContext* ctx, const Datum& values, const Datum& filter,
Datum* out) override;
/// \brief output type of this kernel (identical to type of values filtered)
std::shared_ptr<DataType> out_type() const override { return type_; }
/// \brief factory for FilterKernels
///
/// \param[in] value_type constructed FilterKernel will support filtering
/// values of this type
/// \param[out] out created kernel
static Status Make(const std::shared_ptr<DataType>& value_type,
std::unique_ptr<FilterKernel>* out);
/// \brief single-array implementation
virtual Status Filter(FunctionContext* ctx, const Array& values,
const BooleanArray& filter, int64_t length,
std::shared_ptr<Array>* out) = 0;
protected:
std::shared_ptr<DataType> type_;
};
} // namespace compute
} // namespace arrow

View File

@ -1,105 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_KERNELS_HASH_H
#define ARROW_COMPUTE_KERNELS_HASH_H
#include <memory>
#include "arrow/compute/kernel.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
struct ArrayData;
namespace compute {
class FunctionContext;
/// \brief Compute unique elements from an array-like object
///
/// Note if a null occurs in the input it will NOT be included in the output.
///
/// \param[in] context the FunctionContext
/// \param[in] datum array-like input
/// \param[out] out result as Array
///
/// \since 0.8.0
/// \note API not yet finalized
ARROW_EXPORT
Status Unique(FunctionContext* context, const Datum& datum, std::shared_ptr<Array>* out);
// Constants for accessing the output of ValueCounts
ARROW_EXPORT extern const char kValuesFieldName[];
ARROW_EXPORT extern const char kCountsFieldName[];
ARROW_EXPORT extern const int32_t kValuesFieldIndex;
ARROW_EXPORT extern const int32_t kCountsFieldIndex;
/// \brief Return counts of unique elements from an array-like object.
///
/// Note that the counts do not include counts for nulls in the array. These can be
/// obtained separately from metadata.
///
/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values
/// which can lead to unexpected results if the input Array has these values.
///
/// \param[in] context the FunctionContext
/// \param[in] value array-like input
/// \param[out] counts An array of <input type "Values", int64_t "Counts"> structs.
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status ValueCounts(FunctionContext* context, const Datum& value,
std::shared_ptr<Array>* counts);
/// \brief Dictionary-encode values in an array-like object
/// \param[in] context the FunctionContext
/// \param[in] data array-like input
/// \param[out] out result with same shape and type as input
///
/// \since 0.8.0
/// \note API not yet finalized
ARROW_EXPORT
Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out);
// TODO(wesm): Define API for incremental dictionary encoding
// TODO(wesm): Define API for regularizing DictionaryArray objects with
// different dictionaries
//
// ARROW_EXPORT
// Status DictionaryEncode(FunctionContext* context, const Datum& data,
// const Array& prior_dictionary, Datum* out);
// TODO(wesm): Implement these next
// ARROW_EXPORT
// Status Match(FunctionContext* context, const Datum& values, const Datum& member_set,
// Datum* out);
// ARROW_EXPORT
// Status IsIn(FunctionContext* context, const Datum& values, const Datum& member_set,
// Datum* out);
} // namespace compute
} // namespace arrow
#endif // ARROW_COMPUTE_KERNELS_HASH_H

View File

@ -1,66 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <type_traits>
#include "arrow/status.h"
#include "arrow/type.h"
#include "arrow/type_traits.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
namespace compute {
struct Datum;
class FunctionContext;
class AggregateFunction;
ARROW_EXPORT
std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
FunctionContext* context);
/// \brief Compute the mean of a numeric array.
///
/// \param[in] context the FunctionContext
/// \param[in] value datum to compute the mean, expecting Array
/// \param[out] mean datum of the computed mean as a DoubleScalar
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Mean(FunctionContext* context, const Datum& value, Datum* mean);
/// \brief Compute the mean of a numeric array.
///
/// \param[in] context the FunctionContext
/// \param[in] array to compute the mean
/// \param[out] mean datum of the computed mean as a DoubleScalar
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Mean(FunctionContext* context, const Array& array, Datum* mean);
} // namespace compute
}; // namespace arrow

View File

@ -1,70 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
class Status;
namespace compute {
struct Datum;
class FunctionContext;
class AggregateFunction;
/// \brief Return a Sum Kernel
///
/// \param[in] type required to specialize the kernel
/// \param[in] context the FunctionContext
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
FunctionContext* context);
/// \brief Sum values of a numeric array.
///
/// \param[in] context the FunctionContext
/// \param[in] value datum to sum, expecting Array or ChunkedArray
/// \param[out] out resulting datum
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Sum(FunctionContext* context, const Datum& value, Datum* out);
/// \brief Sum values of a numeric array.
///
/// \param[in] context the FunctionContext
/// \param[in] array to sum
/// \param[out] out resulting datum
///
/// \since 0.13.0
/// \note API not yet finalized
ARROW_EXPORT
Status Sum(FunctionContext* context, const Array& array, Datum* out);
} // namespace compute
} // namespace arrow

View File

@ -1,101 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/compute/kernel.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
namespace compute {
class FunctionContext;
struct ARROW_EXPORT TakeOptions {};
/// \brief Take from an array of values at indices in another array
///
/// The output array will be of the same type as the input values
/// array, with elements taken from the values array at the given
/// indices. If an index is null then the taken element will be null.
///
/// For example given values = ["a", "b", "c", null, "e", "f"] and
/// indices = [2, 1, null, 3], the output will be
/// = [values[2], values[1], null, values[3]]
/// = ["c", "b", null, null]
///
/// \param[in] ctx the FunctionContext
/// \param[in] values array from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting array
ARROW_EXPORT
Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
const TakeOptions& options, std::shared_ptr<Array>* out);
/// \brief Take from an array of values at indices in another array
///
/// \param[in] ctx the FunctionContext
/// \param[in] values datum from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting datum
ARROW_EXPORT
Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices,
const TakeOptions& options, Datum* out);
/// \brief BinaryKernel implementing Take operation
class ARROW_EXPORT TakeKernel : public BinaryKernel {
public:
explicit TakeKernel(const std::shared_ptr<DataType>& type, TakeOptions options = {})
: type_(type) {}
/// \brief BinaryKernel interface
///
/// delegates to subclasses via Take()
Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices,
Datum* out) override;
/// \brief output type of this kernel (identical to type of values taken)
std::shared_ptr<DataType> out_type() const override { return type_; }
/// \brief factory for TakeKernels
///
/// \param[in] value_type constructed TakeKernel will support taking
/// values of this type
/// \param[in] index_type constructed TakeKernel will support taking
/// with indices of this type
/// \param[out] out created kernel
static Status Make(const std::shared_ptr<DataType>& value_type,
const std::shared_ptr<DataType>& index_type,
std::unique_ptr<TakeKernel>* out);
/// \brief single-array implementation
virtual Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
std::shared_ptr<Array>* out) = 0;
protected:
std::shared_ptr<DataType> type_;
};
} // namespace compute
} // namespace arrow

View File

@ -1,308 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Metadata objects for creating well-typed expressions. These are distinct
// from (and higher level than) arrow::DataType as some type parameters (like
// decimal scale and precision) may not be known at expression build time, and
// these are resolved later on evaluation
#pragma once
#include <memory>
#include <string>
#include "arrow/compute/type_fwd.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Status;
namespace compute {
class Expr;
/// \brief An object that represents either a single concrete value type or a
/// group of related types, to help with expression type validation and other
/// purposes
class ARROW_EXPORT LogicalType {
public:
enum Id {
ANY,
NUMBER,
INTEGER,
SIGNED_INTEGER,
UNSIGNED_INTEGER,
FLOATING,
NULL_,
BOOL,
UINT8,
INT8,
UINT16,
INT16,
UINT32,
INT32,
UINT64,
INT64,
FLOAT16,
FLOAT32,
FLOAT64,
BINARY,
UTF8,
DATE,
TIME,
TIMESTAMP,
DECIMAL,
LIST,
STRUCT
};
Id id() const { return id_; }
virtual ~LogicalType() = default;
virtual std::string ToString() const = 0;
/// \brief Check if expression is an instance of this type class
virtual bool IsInstance(const Expr& expr) const = 0;
/// \brief Get a logical expression type from a concrete Arrow in-memory
/// array type
static Status FromArrow(const ::arrow::DataType& type, LogicalTypePtr* out);
protected:
explicit LogicalType(Id id) : id_(id) {}
Id id_;
};
namespace type {
/// \brief Logical type for any value type
class ARROW_EXPORT Any : public LogicalType {
public:
Any() : LogicalType(LogicalType::ANY) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for null
class ARROW_EXPORT Null : public LogicalType {
public:
Null() : LogicalType(LogicalType::NULL_) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for concrete boolean
class ARROW_EXPORT Bool : public LogicalType {
public:
Bool() : LogicalType(LogicalType::BOOL) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for any number (integer or floating point)
class ARROW_EXPORT Number : public LogicalType {
public:
Number() : Number(LogicalType::NUMBER) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit Number(Id type_id) : LogicalType(type_id) {}
};
/// \brief Logical type for any integer
class ARROW_EXPORT Integer : public Number {
public:
Integer() : Integer(LogicalType::INTEGER) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit Integer(Id type_id) : Number(type_id) {}
};
/// \brief Logical type for any floating point number
class ARROW_EXPORT Floating : public Number {
public:
Floating() : Floating(LogicalType::FLOATING) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit Floating(Id type_id) : Number(type_id) {}
};
/// \brief Logical type for any signed integer
class ARROW_EXPORT SignedInteger : public Integer {
public:
SignedInteger() : SignedInteger(LogicalType::SIGNED_INTEGER) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit SignedInteger(Id type_id) : Integer(type_id) {}
};
/// \brief Logical type for any unsigned integer
class ARROW_EXPORT UnsignedInteger : public Integer {
public:
UnsignedInteger() : UnsignedInteger(LogicalType::UNSIGNED_INTEGER) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit UnsignedInteger(Id type_id) : Integer(type_id) {}
};
/// \brief Logical type for int8
class ARROW_EXPORT Int8 : public SignedInteger {
public:
Int8() : SignedInteger(LogicalType::INT8) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for int16
class ARROW_EXPORT Int16 : public SignedInteger {
public:
Int16() : SignedInteger(LogicalType::INT16) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for int32
class ARROW_EXPORT Int32 : public SignedInteger {
public:
Int32() : SignedInteger(LogicalType::INT32) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for int64
class ARROW_EXPORT Int64 : public SignedInteger {
public:
Int64() : SignedInteger(LogicalType::INT64) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for uint8
class ARROW_EXPORT UInt8 : public UnsignedInteger {
public:
UInt8() : UnsignedInteger(LogicalType::UINT8) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for uint16
class ARROW_EXPORT UInt16 : public UnsignedInteger {
public:
UInt16() : UnsignedInteger(LogicalType::UINT16) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for uint32
class ARROW_EXPORT UInt32 : public UnsignedInteger {
public:
UInt32() : UnsignedInteger(LogicalType::UINT32) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for uint64
class ARROW_EXPORT UInt64 : public UnsignedInteger {
public:
UInt64() : UnsignedInteger(LogicalType::UINT64) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for 16-bit floating point
class ARROW_EXPORT Float16 : public Floating {
public:
Float16() : Floating(LogicalType::FLOAT16) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for 32-bit floating point
class ARROW_EXPORT Float32 : public Floating {
public:
Float32() : Floating(LogicalType::FLOAT32) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for 64-bit floating point
class ARROW_EXPORT Float64 : public Floating {
public:
Float64() : Floating(LogicalType::FLOAT64) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
/// \brief Logical type for variable-size binary
class ARROW_EXPORT Binary : public LogicalType {
public:
Binary() : Binary(LogicalType::BINARY) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
protected:
explicit Binary(Id type_id) : LogicalType(type_id) {}
};
/// \brief Logical type for variable-size binary
class ARROW_EXPORT Utf8 : public Binary {
public:
Utf8() : Binary(LogicalType::UTF8) {}
bool IsInstance(const Expr& expr) const override;
std::string ToString() const override;
};
#define SIMPLE_TYPE_FACTORY(NAME) ARROW_EXPORT LogicalTypePtr NAME();
SIMPLE_TYPE_FACTORY(any);
SIMPLE_TYPE_FACTORY(null);
SIMPLE_TYPE_FACTORY(boolean);
SIMPLE_TYPE_FACTORY(number);
SIMPLE_TYPE_FACTORY(integer);
SIMPLE_TYPE_FACTORY(signed_integer);
SIMPLE_TYPE_FACTORY(unsigned_integer);
SIMPLE_TYPE_FACTORY(floating);
SIMPLE_TYPE_FACTORY(int8);
SIMPLE_TYPE_FACTORY(int16);
SIMPLE_TYPE_FACTORY(int32);
SIMPLE_TYPE_FACTORY(int64);
SIMPLE_TYPE_FACTORY(uint8);
SIMPLE_TYPE_FACTORY(uint16);
SIMPLE_TYPE_FACTORY(uint32);
SIMPLE_TYPE_FACTORY(uint64);
SIMPLE_TYPE_FACTORY(float16);
SIMPLE_TYPE_FACTORY(float32);
SIMPLE_TYPE_FACTORY(float64);
SIMPLE_TYPE_FACTORY(binary);
SIMPLE_TYPE_FACTORY(utf8);
#undef SIMPLE_TYPE_FACTORY
} // namespace type
} // namespace compute
} // namespace arrow

View File

@ -1,52 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include <vector>
#include "arrow/compute/type_fwd.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Status;
namespace compute {
/// \brief An operation is a node in a computation graph, taking input data
/// expression dependencies and emitting an output expression
class ARROW_EXPORT Operation : public std::enable_shared_from_this<Operation> {
public:
virtual ~Operation() = default;
/// \brief Check input expression arguments and output the type of resulting
/// expression that this operation produces. If the input arguments are
/// invalid, error Status is returned
/// \param[out] out the returned well-typed expression
/// \return success or failure
virtual Status ToExpr(ExprPtr* out) const = 0;
/// \brief Return the input expressions used to instantiate the
/// operation. The default implementation returns an empty vector
/// \return a vector of expressions
virtual std::vector<ExprPtr> input_args() const;
};
} // namespace compute
} // namespace arrow

View File

@ -1,110 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_COMPUTE_TEST_UTIL_H
#define ARROW_COMPUTE_TEST_UTIL_H
#include <memory>
#include <vector>
#include <gmock/gmock.h>
#include "arrow/array.h"
#include "arrow/memory_pool.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/type.h"
#include "arrow/compute/context.h"
#include "arrow/compute/kernel.h"
namespace arrow {
namespace compute {
class ComputeFixture {
public:
ComputeFixture() : ctx_(default_memory_pool()) {}
protected:
FunctionContext ctx_;
};
class MockUnaryKernel : public UnaryKernel {
public:
MOCK_METHOD3(Call, Status(FunctionContext* ctx, const Datum& input, Datum* out));
MOCK_CONST_METHOD0(out_type, std::shared_ptr<DataType>());
};
class MockBinaryKernel : public BinaryKernel {
public:
MOCK_METHOD4(Call, Status(FunctionContext* ctx, const Datum& left, const Datum& right,
Datum* out));
MOCK_CONST_METHOD0(out_type, std::shared_ptr<DataType>());
};
template <typename Type, typename T>
std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
const std::vector<T>& values,
const std::vector<bool>& is_valid) {
std::shared_ptr<Array> result;
if (is_valid.size() > 0) {
ArrayFromVector<Type, T>(type, is_valid, values, &result);
} else {
ArrayFromVector<Type, T>(type, values, &result);
}
return result;
}
template <typename Type, typename Enable = void>
struct DatumEqual {};
template <typename Type>
struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::value>::type> {
static constexpr double kArbitraryDoubleErrorBound = 1.0;
using ScalarType = typename TypeTraits<Type>::ScalarType;
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
ASSERT_EQ(lhs.kind(), rhs.kind());
if (lhs.kind() == Datum::SCALAR) {
auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
ASSERT_EQ(left->is_valid, right->is_valid);
ASSERT_EQ(left->type->id(), right->type->id());
ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
}
}
};
template <typename Type>
struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
using ScalarType = typename TypeTraits<Type>::ScalarType;
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
ASSERT_EQ(lhs.kind(), rhs.kind());
if (lhs.kind() == Datum::SCALAR) {
auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
ASSERT_EQ(left->is_valid, right->is_valid);
ASSERT_EQ(left->type->id(), right->type->id());
ASSERT_EQ(left->value, right->value);
}
}
};
} // namespace compute
} // namespace arrow
#endif

View File

@ -1,38 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include <memory>
#include "arrow/type_fwd.h"
namespace arrow {
namespace compute {
class Expr;
class LogicalType;
class Operation;
using ArrowTypePtr = std::shared_ptr<::arrow::DataType>;
using ExprPtr = std::shared_ptr<Expr>;
using ConstOpPtr = std::shared_ptr<const Operation>;
using OpPtr = std::shared_ptr<Operation>;
using LogicalTypePtr = std::shared_ptr<LogicalType>;
} // namespace compute
} // namespace arrow

View File

@ -1,24 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_API_H
#define ARROW_CSV_API_H
#include "arrow/csv/options.h"
#include "arrow/csv/reader.h"
#endif // ARROW_CSV_API_H

View File

@ -1,69 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_CHUNKER_H
#define ARROW_CSV_CHUNKER_H
#include <cstdint>
#include "arrow/csv/options.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
namespace arrow {
namespace csv {
/// \class Chunker
/// \brief A reusable block-based chunker for CSV data
///
/// The chunker takes a block of CSV data and finds a suitable place
/// to cut it up without splitting a row.
/// If the block is truncated (i.e. not all data can be chunked), it is up
/// to the caller to arrange the next block to start with the trailing data.
///
/// Note: if the previous block ends with CR (0x0d) and a new block starts
/// with LF (0x0a), the chunker will consider the leading newline as an empty line.
class ARROW_EXPORT Chunker {
public:
explicit Chunker(ParseOptions options);
/// \brief Carve up a chunk in a block of data
///
/// Process a block of CSV data, reading up to size bytes.
/// The number of bytes in the chunk is returned in out_size.
Status Process(const char* data, uint32_t size, uint32_t* out_size);
protected:
ARROW_DISALLOW_COPY_AND_ASSIGN(Chunker);
// Like Process(), but specialized for some parsing options
template <bool quoting, bool escaping>
Status ProcessSpecialized(const char* data, uint32_t size, uint32_t* out_size);
// Detect a single line from the data pointer. Return the line end,
// or nullptr if the remaining line is truncated.
template <bool quoting, bool escaping>
inline const char* ReadLine(const char* data, const char* data_end);
ParseOptions options_;
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_CHUNKER_H

View File

@ -1,87 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_COLUMN_BUILDER_H
#define ARROW_CSV_COLUMN_BUILDER_H
#include <cstdint>
#include <memory>
#include "arrow/array.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class ChunkedArray;
class DataType;
namespace internal {
class TaskGroup;
} // namespace internal
namespace csv {
class BlockParser;
struct ConvertOptions;
class ARROW_EXPORT ColumnBuilder {
public:
virtual ~ColumnBuilder() = default;
/// Spawn a task that will try to convert and append the given CSV block.
/// All calls to Append() should happen on the same thread, otherwise
/// call Insert() instead.
virtual void Append(const std::shared_ptr<BlockParser>& parser);
/// Spawn a task that will try to convert and insert the given CSV block
virtual void Insert(int64_t block_index,
const std::shared_ptr<BlockParser>& parser) = 0;
/// Return the final chunked array. The TaskGroup _must_ have finished!
virtual Status Finish(std::shared_ptr<ChunkedArray>* out) = 0;
/// Change the task group. The previous TaskGroup _must_ have finished!
void SetTaskGroup(const std::shared_ptr<internal::TaskGroup>& task_group);
std::shared_ptr<internal::TaskGroup> task_group() { return task_group_; }
/// Construct a strictly-typed ColumnBuilder.
static Status Make(const std::shared_ptr<DataType>& type, int32_t col_index,
const ConvertOptions& options,
const std::shared_ptr<internal::TaskGroup>& task_group,
std::shared_ptr<ColumnBuilder>* out);
/// Construct a type-inferring ColumnBuilder.
static Status Make(int32_t col_index, const ConvertOptions& options,
const std::shared_ptr<internal::TaskGroup>& task_group,
std::shared_ptr<ColumnBuilder>* out);
protected:
explicit ColumnBuilder(const std::shared_ptr<internal::TaskGroup>& task_group)
: task_group_(task_group) {}
std::shared_ptr<internal::TaskGroup> task_group_;
ArrayVector chunks_;
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_COLUMN_BUILDER_H

View File

@ -1,68 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_CONVERTER_H
#define ARROW_CSV_CONVERTER_H
#include <cstdint>
#include <memory>
#include "arrow/csv/options.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
namespace arrow {
class Array;
class DataType;
class MemoryPool;
class Status;
namespace csv {
class BlockParser;
class ARROW_EXPORT Converter {
public:
Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
MemoryPool* pool);
virtual ~Converter() = default;
virtual Status Convert(const BlockParser& parser, int32_t col_index,
std::shared_ptr<Array>* out) = 0;
std::shared_ptr<DataType> type() const { return type_; }
static Status Make(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
std::shared_ptr<Converter>* out);
static Status Make(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
MemoryPool* pool, std::shared_ptr<Converter>* out);
protected:
ARROW_DISALLOW_COPY_AND_ASSIGN(Converter);
virtual Status Initialize() = 0;
const ConvertOptions options_;
MemoryPool* pool_;
std::shared_ptr<DataType> type_;
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_CONVERTER_H

View File

@ -1,98 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_OPTIONS_H
#define ARROW_CSV_OPTIONS_H
#include <cstdint>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "arrow/util/visibility.h"
namespace arrow {
class DataType;
namespace csv {
struct ARROW_EXPORT ParseOptions {
// Parsing options
// Field delimiter
char delimiter = ',';
// Whether quoting is used
bool quoting = true;
// Quoting character (if `quoting` is true)
char quote_char = '"';
// Whether a quote inside a value is double-quoted
bool double_quote = true;
// Whether escaping is used
bool escaping = false;
// Escaping character (if `escaping` is true)
char escape_char = '\\';
// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters
bool newlines_in_values = false;
// Whether empty lines are ignored. If false, an empty line represents
// a single empty value (assuming a one-column CSV file).
bool ignore_empty_lines = true;
// XXX Should this be in ReadOptions?
// Number of header rows to skip (including the first row containing column names)
int32_t header_rows = 1;
static ParseOptions Defaults();
};
struct ARROW_EXPORT ConvertOptions {
// Conversion options
// Whether to check UTF8 validity of string columns
bool check_utf8 = true;
// Optional per-column types (disabling type inference on those columns)
std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
// Recognized spellings for null values
std::vector<std::string> null_values;
// Recognized spellings for boolean values
std::vector<std::string> true_values;
std::vector<std::string> false_values;
// Whether string / binary columns can have null values.
// If true, then strings in "null_values" are considered null for string columns.
// If false, then all strings are valid string values.
bool strings_can_be_null = false;
static ConvertOptions Defaults();
};
struct ARROW_EXPORT ReadOptions {
// Reader options
// Whether to use the global CPU thread pool
bool use_threads = true;
// Block size we request from the IO layer; also determines the size of
// chunks when use_threads is true
int32_t block_size = 1 << 20; // 1 MB
static ReadOptions Defaults();
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_OPTIONS_H

View File

@ -1,149 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_PARSER_H
#define ARROW_CSV_PARSER_H
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>
#include "arrow/buffer.h"
#include "arrow/csv/options.h"
#include "arrow/status.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
namespace arrow {
class MemoryPool;
namespace csv {
constexpr int32_t kMaxParserNumRows = 100000;
/// \class BlockParser
/// \brief A reusable block-based parser for CSV data
///
/// The parser takes a block of CSV data and delimits rows and fields,
/// unquoting and unescaping them on the fly. Parsed data is own by the
/// parser, so the original buffer can be discarded after Parse() returns.
///
/// If the block is truncated (i.e. not all data can be parsed), it is up
/// to the caller to arrange the next block to start with the trailing data.
/// Also, if the previous block ends with CR (0x0d) and a new block starts
/// with LF (0x0a), the parser will consider the leading newline as an empty
/// line; the caller should therefore strip it.
class ARROW_EXPORT BlockParser {
public:
explicit BlockParser(ParseOptions options, int32_t num_cols = -1,
int32_t max_num_rows = kMaxParserNumRows);
explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1,
int32_t max_num_rows = kMaxParserNumRows);
/// \brief Parse a block of data
///
/// Parse a block of CSV data, ingesting up to max_num_rows rows.
/// The number of bytes actually parsed is returned in out_size.
Status Parse(const char* data, uint32_t size, uint32_t* out_size);
/// \brief Parse the final block of data
///
/// Like Parse(), but called with the final block in a file.
/// The last row may lack a trailing line separator.
Status ParseFinal(const char* data, uint32_t size, uint32_t* out_size);
/// \brief Return the number of parsed rows
int32_t num_rows() const { return num_rows_; }
/// \brief Return the number of parsed columns
int32_t num_cols() const { return num_cols_; }
/// \brief Return the total size in bytes of parsed data
uint32_t num_bytes() const { return parsed_size_; }
/// \brief Visit parsed values in a column
///
/// The signature of the visitor is
/// Status(const uint8_t* data, uint32_t size, bool quoted)
template <typename Visitor>
Status VisitColumn(int32_t col_index, Visitor&& visit) const {
for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) {
const auto& values_buffer = values_buffers_[buf_index];
const auto values = reinterpret_cast<const ValueDesc*>(values_buffer->data());
const auto max_pos =
static_cast<int32_t>(values_buffer->size() / sizeof(ValueDesc)) - 1;
for (int32_t pos = col_index; pos < max_pos; pos += num_cols_) {
auto start = values[pos].offset;
auto stop = values[pos + 1].offset;
auto quoted = values[pos + 1].quoted;
ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted));
}
}
return Status::OK();
}
protected:
ARROW_DISALLOW_COPY_AND_ASSIGN(BlockParser);
Status DoParse(const char* data, uint32_t size, bool is_final, uint32_t* out_size);
template <typename SpecializedOptions>
Status DoParseSpecialized(const char* data, uint32_t size, bool is_final,
uint32_t* out_size);
template <typename SpecializedOptions, typename ValuesWriter, typename ParsedWriter>
Status ParseChunk(ValuesWriter* values_writer, ParsedWriter* parsed_writer,
const char* data, const char* data_end, bool is_final,
int32_t rows_in_chunk, const char** out_data, bool* finished_parsing);
// Parse a single line from the data pointer
template <typename SpecializedOptions, typename ValuesWriter, typename ParsedWriter>
Status ParseLine(ValuesWriter* values_writer, ParsedWriter* parsed_writer,
const char* data, const char* data_end, bool is_final,
const char** out_data);
MemoryPool* pool_;
const ParseOptions options_;
// The number of rows parsed from the block
int32_t num_rows_;
// The number of columns (can be -1 at start)
int32_t num_cols_;
// The maximum number of rows to parse from this block
int32_t max_num_rows_;
// Linear scratchpad for parsed values
struct ValueDesc {
uint32_t offset : 31;
bool quoted : 1;
};
// XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes?
// It may help with null parsing...
std::vector<std::shared_ptr<Buffer>> values_buffers_;
std::shared_ptr<Buffer> parsed_buffer_;
const uint8_t* parsed_;
int32_t values_size_;
int32_t parsed_size_;
class ResizableValuesWriter;
class PresizedValuesWriter;
class PresizedParsedWriter;
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_PARSER_H

View File

@ -1,53 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_READER_H
#define ARROW_CSV_READER_H
#include <memory>
#include "arrow/csv/options.h" // IWYU pragma: keep
#include "arrow/status.h"
#include "arrow/util/visibility.h"
namespace arrow {
class MemoryPool;
class Table;
namespace io {
class InputStream;
} // namespace io
namespace csv {
class ARROW_EXPORT TableReader {
public:
virtual ~TableReader() = default;
virtual Status Read(std::shared_ptr<Table>* out) = 0;
// XXX pass optional schema?
static Status Make(MemoryPool* pool, std::shared_ptr<io::InputStream> input,
const ReadOptions&, const ParseOptions&, const ConvertOptions&,
std::shared_ptr<TableReader>* out);
};
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_READER_H

View File

@ -1,71 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#ifndef ARROW_CSV_TEST_COMMON_H
#define ARROW_CSV_TEST_COMMON_H
#include <memory>
#include <string>
#include <vector>
#include "arrow/csv/parser.h"
#include "arrow/testing/gtest_util.h"
namespace arrow {
namespace csv {
std::string MakeCSVData(std::vector<std::string> lines) {
std::string s;
for (const auto& line : lines) {
s += line;
}
return s;
}
// Make a BlockParser from a vector of lines representing a CSV file
void MakeCSVParser(std::vector<std::string> lines, ParseOptions options,
std::shared_ptr<BlockParser>* out) {
auto csv = MakeCSVData(lines);
auto parser = std::make_shared<BlockParser>(options);
uint32_t out_size;
ASSERT_OK(parser->Parse(csv.data(), static_cast<uint32_t>(csv.size()), &out_size));
ASSERT_EQ(out_size, csv.size()) << "trailing CSV data not parsed";
*out = parser;
}
void MakeCSVParser(std::vector<std::string> lines, std::shared_ptr<BlockParser>* out) {
MakeCSVParser(lines, ParseOptions::Defaults(), out);
}
// Make a BlockParser from a vector of strings representing a single CSV column
void MakeColumnParser(std::vector<std::string> items, std::shared_ptr<BlockParser>* out) {
auto options = ParseOptions::Defaults();
// Need this to test for null (empty) values
options.ignore_empty_lines = false;
std::vector<std::string> lines;
for (const auto& item : items) {
lines.push_back(item + '\n');
}
MakeCSVParser(lines, options, out);
ASSERT_EQ((*out)->num_cols(), 1) << "Should have seen only 1 CSV column";
ASSERT_EQ((*out)->num_rows(), items.size());
}
} // namespace csv
} // namespace arrow
#endif // ARROW_CSV_TEST_COMMON_H

View File

@ -1,26 +0,0 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#pragma once
#include "arrow/dataset/dataset.h"
#include "arrow/dataset/discovery.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_csv.h"
#include "arrow/dataset/file_feather.h"
#include "arrow/dataset/file_parquet.h"
#include "arrow/dataset/scanner.h"

Some files were not shown because too many files have changed in this diff Show More