mirror of
https://gitee.com/milvus-io/milvus.git
synced 2024-11-30 19:08:30 +08:00
delete knowhere
Former-commit-id: c04ad4797de102962ee39c6cf364926a120ddcd8
This commit is contained in:
parent
fe165a6165
commit
cc5ab807fc
@ -1,67 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
|
||||
#define _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
|
||||
|
||||
#include "inc/Socket/Common.h"
|
||||
#include "AggregatorSettings.h"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <atomic>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Aggregator
|
||||
{
|
||||
|
||||
enum RemoteMachineStatus : uint8_t
|
||||
{
|
||||
Disconnected = 0,
|
||||
|
||||
Connecting,
|
||||
|
||||
Connected
|
||||
};
|
||||
|
||||
|
||||
struct RemoteMachine
|
||||
{
|
||||
RemoteMachine();
|
||||
|
||||
std::string m_address;
|
||||
|
||||
std::string m_port;
|
||||
|
||||
Socket::ConnectionID m_connectionID;
|
||||
|
||||
std::atomic<RemoteMachineStatus> m_status;
|
||||
};
|
||||
|
||||
class AggregatorContext
|
||||
{
|
||||
public:
|
||||
AggregatorContext(const std::string& p_filePath);
|
||||
|
||||
~AggregatorContext();
|
||||
|
||||
bool IsInitialized() const;
|
||||
|
||||
const std::vector<std::shared_ptr<RemoteMachine>>& GetRemoteServers() const;
|
||||
|
||||
const std::shared_ptr<AggregatorSettings>& GetSettings() const;
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<RemoteMachine>> m_remoteServers;
|
||||
|
||||
std::shared_ptr<AggregatorSettings> m_settings;
|
||||
|
||||
bool m_initialized;
|
||||
};
|
||||
|
||||
} // namespace Aggregator
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_AGGREGATOR_AGGREGATORCONTEXT_H_
|
@ -1,53 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
|
||||
#define _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
|
||||
|
||||
#include "inc/Socket/RemoteSearchQuery.h"
|
||||
#include "inc/Socket/Packet.h"
|
||||
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Aggregator
|
||||
{
|
||||
|
||||
typedef std::shared_ptr<Socket::RemoteSearchResult> AggregatorResult;
|
||||
|
||||
class AggregatorExecutionContext
|
||||
{
|
||||
public:
|
||||
AggregatorExecutionContext(std::size_t p_totalServerNumber,
|
||||
Socket::PacketHeader p_requestHeader);
|
||||
|
||||
~AggregatorExecutionContext();
|
||||
|
||||
std::size_t GetServerNumber() const;
|
||||
|
||||
AggregatorResult& GetResult(std::size_t p_num);
|
||||
|
||||
const Socket::PacketHeader& GetRequestHeader() const;
|
||||
|
||||
bool IsCompletedAfterFinsh(std::uint32_t p_finishedCount);
|
||||
|
||||
private:
|
||||
std::atomic<std::uint32_t> m_unfinishedCount;
|
||||
|
||||
std::vector<AggregatorResult> m_results;
|
||||
|
||||
Socket::PacketHeader m_requestHeader;
|
||||
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace Aggregator
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_AGGREGATOR_AGGREGATOREXECUTIONCONTEXT_H_
|
||||
|
@ -1,88 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
|
||||
#define _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
|
||||
|
||||
#include "AggregatorContext.h"
|
||||
#include "AggregatorExecutionContext.h"
|
||||
#include "inc/Socket/Server.h"
|
||||
#include "inc/Socket/Client.h"
|
||||
#include "inc/Socket/ResourceManager.h"
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <condition_variable>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Aggregator
|
||||
{
|
||||
|
||||
class AggregatorService
|
||||
{
|
||||
public:
|
||||
AggregatorService();
|
||||
|
||||
~AggregatorService();
|
||||
|
||||
bool Initialize();
|
||||
|
||||
void Run();
|
||||
|
||||
private:
|
||||
|
||||
void StartClient();
|
||||
|
||||
void StartListen();
|
||||
|
||||
void WaitForShutdown();
|
||||
|
||||
void ConnectToPendingServers();
|
||||
|
||||
void AddToPendingServers(std::shared_ptr<RemoteMachine> p_remoteServer);
|
||||
|
||||
void SearchRequestHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
|
||||
|
||||
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
|
||||
|
||||
void AggregateResults(std::shared_ptr<AggregatorExecutionContext> p_exectionContext);
|
||||
|
||||
std::shared_ptr<AggregatorContext> GetContext();
|
||||
|
||||
private:
|
||||
typedef std::function<void(Socket::RemoteSearchResult)> AggregatorCallback;
|
||||
|
||||
std::shared_ptr<AggregatorContext> m_aggregatorContext;
|
||||
|
||||
std::shared_ptr<Socket::Server> m_socketServer;
|
||||
|
||||
std::shared_ptr<Socket::Client> m_socketClient;
|
||||
|
||||
bool m_initalized;
|
||||
|
||||
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
|
||||
|
||||
boost::asio::io_context m_ioContext;
|
||||
|
||||
boost::asio::signal_set m_shutdownSignals;
|
||||
|
||||
std::vector<std::shared_ptr<RemoteMachine>> m_pendingConnectServers;
|
||||
|
||||
std::mutex m_pendingConnectServersMutex;
|
||||
|
||||
boost::asio::deadline_timer m_pendingConnectServersTimer;
|
||||
|
||||
Socket::ResourceManager<AggregatorCallback> m_aggregatorCallbackManager;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace Aggregator
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_AGGREGATOR_AGGREGATORSERVICE_H_
|
@ -1,39 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
|
||||
#define _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
|
||||
|
||||
#include "../Core/Common.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Aggregator
|
||||
{
|
||||
|
||||
struct AggregatorSettings
|
||||
{
|
||||
AggregatorSettings();
|
||||
|
||||
std::string m_listenAddr;
|
||||
|
||||
std::string m_listenPort;
|
||||
|
||||
std::uint32_t m_searchTimeout;
|
||||
|
||||
SizeType m_threadNum;
|
||||
|
||||
SizeType m_socketThreadNum;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace Aggregator
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_AGGREGATOR_AGGREGATORSETTINGS_H_
|
||||
|
@ -1,80 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_CLIENT_CLIENTWRAPPER_H_
|
||||
#define _SPTAG_CLIENT_CLIENTWRAPPER_H_
|
||||
|
||||
#include "inc/Socket/Client.h"
|
||||
#include "inc/Socket/RemoteSearchQuery.h"
|
||||
#include "inc/Socket/ResourceManager.h"
|
||||
#include "Options.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <functional>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Client
|
||||
{
|
||||
|
||||
class ClientWrapper
|
||||
{
|
||||
public:
|
||||
typedef std::function<void(Socket::RemoteSearchResult)> Callback;
|
||||
|
||||
ClientWrapper(const ClientOptions& p_options);
|
||||
|
||||
~ClientWrapper();
|
||||
|
||||
void SendQueryAsync(const Socket::RemoteQuery& p_query,
|
||||
Callback p_callback,
|
||||
const ClientOptions& p_options);
|
||||
|
||||
void WaitAllFinished();
|
||||
|
||||
bool IsAvailable() const;
|
||||
|
||||
private:
|
||||
typedef std::pair<Socket::ConnectionID, Socket::ConnectionID> ConnectionPair;
|
||||
|
||||
Socket::PacketHandlerMapPtr GetHandlerMap();
|
||||
|
||||
void DecreaseUnfnishedJobCount();
|
||||
|
||||
const ConnectionPair& GetConnection();
|
||||
|
||||
void SearchResponseHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
|
||||
|
||||
void HandleDeadConnection(Socket::ConnectionID p_cid);
|
||||
|
||||
private:
|
||||
ClientOptions m_options;
|
||||
|
||||
std::unique_ptr<Socket::Client> m_client;
|
||||
|
||||
std::atomic<std::uint32_t> m_unfinishedJobCount;
|
||||
|
||||
std::atomic_bool m_isWaitingFinish;
|
||||
|
||||
std::condition_variable m_waitingQueue;
|
||||
|
||||
std::mutex m_waitingMutex;
|
||||
|
||||
std::vector<ConnectionPair> m_connections;
|
||||
|
||||
std::atomic<std::uint32_t> m_spinCountOfConnection;
|
||||
|
||||
Socket::ResourceManager<Callback> m_callbackManager;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_CLIENT_OPTIONS_H_
|
@ -1,42 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_CLIENT_OPTIONS_H_
|
||||
#define _SPTAG_CLIENT_OPTIONS_H_
|
||||
|
||||
#include "inc/Helper/ArgumentsParser.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Client
|
||||
{
|
||||
|
||||
class ClientOptions : public Helper::ArgumentsParser
|
||||
{
|
||||
public:
|
||||
ClientOptions();
|
||||
|
||||
virtual ~ClientOptions();
|
||||
|
||||
std::string m_serverAddr;
|
||||
|
||||
std::string m_serverPort;
|
||||
|
||||
// in milliseconds.
|
||||
std::uint32_t m_searchTimeout;
|
||||
|
||||
std::uint32_t m_threadNum;
|
||||
|
||||
std::uint32_t m_socketThreadNum;
|
||||
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_CLIENT_OPTIONS_H_
|
@ -1,112 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_BKT_INDEX_H_
|
||||
#define _SPTAG_BKT_INDEX_H_
|
||||
|
||||
#include "../Common.h"
|
||||
#include "../VectorIndex.h"
|
||||
|
||||
#include "../Common/CommonUtils.h"
|
||||
#include "../Common/DistanceUtils.h"
|
||||
#include "../Common/QueryResultSet.h"
|
||||
#include "../Common/Dataset.h"
|
||||
#include "../Common/WorkSpace.h"
|
||||
#include "../Common/WorkSpacePool.h"
|
||||
#include "../Common/RelativeNeighborhoodGraph.h"
|
||||
#include "../Common/BKTree.h"
|
||||
#include "inc/Helper/SimpleIniReader.h"
|
||||
#include "inc/Helper/StringConvert.h"
|
||||
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <tbb/concurrent_unordered_set.h>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
namespace Helper
|
||||
{
|
||||
class IniReader;
|
||||
}
|
||||
|
||||
namespace BKT
|
||||
{
|
||||
template<typename T>
|
||||
class Index : public VectorIndex
|
||||
{
|
||||
private:
|
||||
// data points
|
||||
COMMON::Dataset<T> m_pSamples;
|
||||
|
||||
// BKT structures.
|
||||
COMMON::BKTree m_pTrees;
|
||||
|
||||
// Graph structure
|
||||
COMMON::RelativeNeighborhoodGraph m_pGraph;
|
||||
|
||||
std::string m_sBKTFilename;
|
||||
std::string m_sGraphFilename;
|
||||
std::string m_sDataPointsFilename;
|
||||
|
||||
std::mutex m_dataLock; // protect data and graph
|
||||
tbb::concurrent_unordered_set<int> m_deletedID;
|
||||
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
|
||||
|
||||
int m_iNumberOfThreads;
|
||||
DistCalcMethod m_iDistCalcMethod;
|
||||
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
|
||||
|
||||
int m_iMaxCheck;
|
||||
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
|
||||
int m_iNumberOfInitialDynamicPivots;
|
||||
int m_iNumberOfOtherDynamicPivots;
|
||||
public:
|
||||
Index()
|
||||
{
|
||||
#define DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr) \
|
||||
VarName = DefaultValue; \
|
||||
|
||||
#include "inc/Core/BKT/ParameterDefinitionList.h"
|
||||
#undef DefineBKTParameter
|
||||
|
||||
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
|
||||
}
|
||||
|
||||
~Index() {}
|
||||
|
||||
inline int GetNumSamples() const { return m_pSamples.R(); }
|
||||
inline int GetFeatureDim() const { return m_pSamples.C(); }
|
||||
|
||||
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
|
||||
inline int GetNumThreads() const { return m_iNumberOfThreads; }
|
||||
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
|
||||
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::BKT; }
|
||||
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
|
||||
|
||||
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
|
||||
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
|
||||
|
||||
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
|
||||
|
||||
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
|
||||
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
|
||||
|
||||
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
|
||||
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
|
||||
ErrorCode SearchIndex(QueryResult &p_query) const;
|
||||
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
|
||||
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
|
||||
|
||||
ErrorCode SetParameter(const char* p_param, const char* p_value);
|
||||
std::string GetParameter(const char* p_param) const;
|
||||
|
||||
private:
|
||||
ErrorCode RefineIndex(const std::string& p_folderPath);
|
||||
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
|
||||
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
|
||||
};
|
||||
} // namespace BKT
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_BKT_INDEX_H_
|
@ -1,36 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef DefineBKTParameter
|
||||
|
||||
// DefineBKTParameter(VarName, VarType, DefaultValue, RepresentStr)
|
||||
DefineBKTParameter(m_sBKTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
|
||||
DefineBKTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
|
||||
DefineBKTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
|
||||
|
||||
DefineBKTParameter(m_pTrees.m_iTreeNumber, int, 1L, "BKTNumber")
|
||||
DefineBKTParameter(m_pTrees.m_iBKTKmeansK, int, 32L, "BKTKmeansK")
|
||||
DefineBKTParameter(m_pTrees.m_iBKTLeafSize, int, 8L, "BKTLeafSize")
|
||||
DefineBKTParameter(m_pTrees.m_iSamples, int, 1000L, "Samples")
|
||||
|
||||
|
||||
DefineBKTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TpTreeNumber")
|
||||
DefineBKTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
|
||||
DefineBKTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTpTreeSplit")
|
||||
|
||||
DefineBKTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
|
||||
DefineBKTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
|
||||
DefineBKTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
|
||||
DefineBKTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
|
||||
DefineBKTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
|
||||
DefineBKTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
|
||||
|
||||
DefineBKTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
|
||||
DefineBKTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
|
||||
|
||||
DefineBKTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
|
||||
DefineBKTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
|
||||
DefineBKTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
|
||||
DefineBKTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
|
||||
|
||||
#endif
|
@ -1,166 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_CORE_COMMONDEFS_H_
|
||||
#define _SPTAG_CORE_COMMONDEFS_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
#include <cmath>
|
||||
|
||||
#ifndef _MSC_VER
|
||||
#include <sys/stat.h>
|
||||
#include <sys/types.h>
|
||||
#define FolderSep '/'
|
||||
#define mkdir(a) mkdir(a, ACCESSPERMS)
|
||||
inline bool direxists(const char* path) {
|
||||
struct stat info;
|
||||
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR);
|
||||
}
|
||||
inline bool fileexists(const char* path) {
|
||||
struct stat info;
|
||||
return stat(path, &info) == 0 && (info.st_mode & S_IFDIR) == 0;
|
||||
}
|
||||
template <class T>
|
||||
inline T min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
template <class T>
|
||||
inline T max(T a, T b) {
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
#ifndef _rotl
|
||||
#define _rotl(x, n) (((x) << (n)) | ((x) >> (32-(n))))
|
||||
#endif
|
||||
|
||||
#else
|
||||
#define WIN32_LEAN_AND_MEAN
|
||||
#include <Windows.h>
|
||||
#include <Psapi.h>
|
||||
#define FolderSep '\\'
|
||||
#define mkdir(a) CreateDirectory(a, NULL)
|
||||
inline bool direxists(const char* path) {
|
||||
auto dwAttr = GetFileAttributes((LPCSTR)path);
|
||||
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY);
|
||||
}
|
||||
inline bool fileexists(const char* path) {
|
||||
auto dwAttr = GetFileAttributes((LPCSTR)path);
|
||||
return (dwAttr != INVALID_FILE_ATTRIBUTES) && (dwAttr & FILE_ATTRIBUTE_DIRECTORY) == 0;
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
typedef std::uint32_t SizeType;
|
||||
|
||||
const float MinDist = (std::numeric_limits<float>::min)();
|
||||
const float MaxDist = (std::numeric_limits<float>::max)();
|
||||
const float Epsilon = 0.000000001f;
|
||||
|
||||
class MyException : public std::exception
|
||||
{
|
||||
private:
|
||||
std::string Exp;
|
||||
public:
|
||||
MyException(std::string e) { Exp = e; }
|
||||
#ifdef _MSC_VER
|
||||
const char* what() const { return Exp.c_str(); }
|
||||
#else
|
||||
const char* what() const noexcept { return Exp.c_str(); }
|
||||
#endif
|
||||
};
|
||||
|
||||
// Type of number index.
|
||||
typedef std::int32_t IndexType;
|
||||
static_assert(std::is_integral<IndexType>::value, "IndexType must be integral type.");
|
||||
|
||||
|
||||
enum class ErrorCode : std::uint16_t
|
||||
{
|
||||
#define DefineErrorCode(Name, Value) Name = Value,
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineErrorCode
|
||||
|
||||
Undefined
|
||||
};
|
||||
static_assert(static_cast<std::uint16_t>(ErrorCode::Undefined) != 0, "Empty ErrorCode!");
|
||||
|
||||
|
||||
enum class DistCalcMethod : std::uint8_t
|
||||
{
|
||||
#define DefineDistCalcMethod(Name) Name,
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineDistCalcMethod
|
||||
|
||||
Undefined
|
||||
};
|
||||
static_assert(static_cast<std::uint8_t>(DistCalcMethod::Undefined) != 0, "Empty DistCalcMethod!");
|
||||
|
||||
|
||||
enum class VectorValueType : std::uint8_t
|
||||
{
|
||||
#define DefineVectorValueType(Name, Type) Name,
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineVectorValueType
|
||||
|
||||
Undefined
|
||||
};
|
||||
static_assert(static_cast<std::uint8_t>(VectorValueType::Undefined) != 0, "Empty VectorValueType!");
|
||||
|
||||
|
||||
enum class IndexAlgoType : std::uint8_t
|
||||
{
|
||||
#define DefineIndexAlgo(Name) Name,
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineIndexAlgo
|
||||
|
||||
Undefined
|
||||
};
|
||||
static_assert(static_cast<std::uint8_t>(IndexAlgoType::Undefined) != 0, "Empty IndexAlgoType!");
|
||||
|
||||
|
||||
template<typename T>
|
||||
constexpr VectorValueType GetEnumValueType()
|
||||
{
|
||||
return VectorValueType::Undefined;
|
||||
}
|
||||
|
||||
|
||||
#define DefineVectorValueType(Name, Type) \
|
||||
template<> \
|
||||
constexpr VectorValueType GetEnumValueType<Type>() \
|
||||
{ \
|
||||
return VectorValueType::Name; \
|
||||
} \
|
||||
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineVectorValueType
|
||||
|
||||
|
||||
inline std::size_t GetValueTypeSize(VectorValueType p_valueType)
|
||||
{
|
||||
switch (p_valueType)
|
||||
{
|
||||
#define DefineVectorValueType(Name, Type) \
|
||||
case VectorValueType::Name: \
|
||||
return sizeof(Type); \
|
||||
|
||||
#include "DefinitionList.h"
|
||||
#undef DefineVectorValueType
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_CORE_COMMONDEFS_H_
|
@ -1,492 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_BKTREE_H_
|
||||
#define _SPTAG_COMMON_BKTREE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "../VectorIndex.h"
|
||||
|
||||
#include "CommonUtils.h"
|
||||
#include "QueryResultSet.h"
|
||||
#include "WorkSpace.h"
|
||||
|
||||
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
// node type for storing BKT
|
||||
struct BKTNode
|
||||
{
|
||||
int centerid;
|
||||
int childStart;
|
||||
int childEnd;
|
||||
|
||||
BKTNode(int cid = -1) : centerid(cid), childStart(-1), childEnd(-1) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct KmeansArgs {
|
||||
int _K;
|
||||
int _D;
|
||||
int _T;
|
||||
T* centers;
|
||||
int* counts;
|
||||
float* newCenters;
|
||||
int* newCounts;
|
||||
char* label;
|
||||
int* clusterIdx;
|
||||
float* clusterDist;
|
||||
T* newTCenters;
|
||||
|
||||
KmeansArgs(int k, int dim, int datasize, int threadnum) : _K(k), _D(dim), _T(threadnum) {
|
||||
centers = new T[k * dim];
|
||||
counts = new int[k];
|
||||
newCenters = new float[threadnum * k * dim];
|
||||
newCounts = new int[threadnum * k];
|
||||
label = new char[datasize];
|
||||
clusterIdx = new int[threadnum * k];
|
||||
clusterDist = new float[threadnum * k];
|
||||
newTCenters = new T[k * dim];
|
||||
}
|
||||
|
||||
~KmeansArgs() {
|
||||
delete[] centers;
|
||||
delete[] counts;
|
||||
delete[] newCenters;
|
||||
delete[] newCounts;
|
||||
delete[] label;
|
||||
delete[] clusterIdx;
|
||||
delete[] clusterDist;
|
||||
delete[] newTCenters;
|
||||
}
|
||||
|
||||
inline void ClearCounts() {
|
||||
memset(newCounts, 0, sizeof(int) * _T * _K);
|
||||
}
|
||||
|
||||
inline void ClearCenters() {
|
||||
memset(newCenters, 0, sizeof(float) * _T * _K * _D);
|
||||
}
|
||||
|
||||
inline void ClearDists(float dist) {
|
||||
for (int i = 0; i < _T * _K; i++) {
|
||||
clusterIdx[i] = -1;
|
||||
clusterDist[i] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
void Shuffle(std::vector<int>& indices, int first, int last) {
|
||||
int* pos = new int[_K];
|
||||
pos[0] = first;
|
||||
for (int k = 1; k < _K; k++) pos[k] = pos[k - 1] + newCounts[k - 1];
|
||||
|
||||
for (int k = 0; k < _K; k++) {
|
||||
if (newCounts[k] == 0) continue;
|
||||
int i = pos[k];
|
||||
while (newCounts[k] > 0) {
|
||||
int swapid = pos[(int)(label[i])] + newCounts[(int)(label[i])] - 1;
|
||||
newCounts[(int)(label[i])]--;
|
||||
std::swap(indices[i], indices[swapid]);
|
||||
std::swap(label[i], label[swapid]);
|
||||
}
|
||||
while (indices[i] != clusterIdx[k]) i++;
|
||||
std::swap(indices[i], indices[pos[k] + counts[k] - 1]);
|
||||
}
|
||||
delete[] pos;
|
||||
}
|
||||
};
|
||||
|
||||
class BKTree
|
||||
{
|
||||
public:
|
||||
BKTree(): m_iTreeNumber(1), m_iBKTKmeansK(32), m_iBKTLeafSize(8), m_iSamples(1000) {}
|
||||
|
||||
BKTree(BKTree& other): m_iTreeNumber(other.m_iTreeNumber),
|
||||
m_iBKTKmeansK(other.m_iBKTKmeansK),
|
||||
m_iBKTLeafSize(other.m_iBKTLeafSize),
|
||||
m_iSamples(other.m_iSamples) {}
|
||||
~BKTree() {}
|
||||
|
||||
inline const BKTNode& operator[](int index) const { return m_pTreeRoots[index]; }
|
||||
inline BKTNode& operator[](int index) { return m_pTreeRoots[index]; }
|
||||
|
||||
inline int size() const { return (int)m_pTreeRoots.size(); }
|
||||
|
||||
inline const std::unordered_map<int, int>& GetSampleMap() const { return m_pSampleCenterMap; }
|
||||
|
||||
template <typename T>
|
||||
void BuildTrees(VectorIndex* index, std::vector<int>* indices = nullptr)
|
||||
{
|
||||
struct BKTStackItem {
|
||||
int index, first, last;
|
||||
BKTStackItem(int index_, int first_, int last_) : index(index_), first(first_), last(last_) {}
|
||||
};
|
||||
std::stack<BKTStackItem> ss;
|
||||
|
||||
std::vector<int> localindices;
|
||||
if (indices == nullptr) {
|
||||
localindices.resize(index->GetNumSamples());
|
||||
for (int i = 0; i < index->GetNumSamples(); i++) localindices[i] = i;
|
||||
}
|
||||
else {
|
||||
localindices.assign(indices->begin(), indices->end());
|
||||
}
|
||||
KmeansArgs<T> args(m_iBKTKmeansK, index->GetFeatureDim(), (int)localindices.size(), omp_get_num_threads());
|
||||
|
||||
m_pSampleCenterMap.clear();
|
||||
for (char i = 0; i < m_iTreeNumber; i++)
|
||||
{
|
||||
std::random_shuffle(localindices.begin(), localindices.end());
|
||||
|
||||
m_pTreeStart.push_back((int)m_pTreeRoots.size());
|
||||
m_pTreeRoots.push_back(BKTNode((int)localindices.size()));
|
||||
std::cout << "Start to build BKTree " << i + 1 << std::endl;
|
||||
|
||||
ss.push(BKTStackItem(m_pTreeStart[i], 0, (int)localindices.size()));
|
||||
while (!ss.empty()) {
|
||||
BKTStackItem item = ss.top(); ss.pop();
|
||||
int newBKTid = (int)m_pTreeRoots.size();
|
||||
m_pTreeRoots[item.index].childStart = newBKTid;
|
||||
if (item.last - item.first <= m_iBKTLeafSize) {
|
||||
for (int j = item.first; j < item.last; j++) {
|
||||
m_pTreeRoots.push_back(BKTNode(localindices[j]));
|
||||
}
|
||||
}
|
||||
else { // clustering the data into BKTKmeansK clusters
|
||||
int numClusters = KmeansClustering(index, localindices, item.first, item.last, args);
|
||||
if (numClusters <= 1) {
|
||||
int end = min(item.last + 1, (int)localindices.size());
|
||||
std::sort(localindices.begin() + item.first, localindices.begin() + end);
|
||||
m_pTreeRoots[item.index].centerid = localindices[item.first];
|
||||
m_pTreeRoots[item.index].childStart = -m_pTreeRoots[item.index].childStart;
|
||||
for (int j = item.first + 1; j < end; j++) {
|
||||
m_pTreeRoots.push_back(BKTNode(localindices[j]));
|
||||
m_pSampleCenterMap[localindices[j]] = m_pTreeRoots[item.index].centerid;
|
||||
}
|
||||
m_pSampleCenterMap[-1 - m_pTreeRoots[item.index].centerid] = item.index;
|
||||
}
|
||||
else {
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
if (args.counts[k] == 0) continue;
|
||||
m_pTreeRoots.push_back(BKTNode(localindices[item.first + args.counts[k] - 1]));
|
||||
if (args.counts[k] > 1) ss.push(BKTStackItem(newBKTid++, item.first, item.first + args.counts[k] - 1));
|
||||
item.first += args.counts[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
m_pTreeRoots[item.index].childEnd = (int)m_pTreeRoots.size();
|
||||
}
|
||||
std::cout << i + 1 << " BKTree built, " << m_pTreeRoots.size() - m_pTreeStart[i] << " " << localindices.size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
|
||||
{
|
||||
int treeNodeSize = (int)m_pTreeRoots.size();
|
||||
|
||||
size_t size = sizeof(int) +
|
||||
sizeof(int) * m_iTreeNumber +
|
||||
sizeof(int) +
|
||||
sizeof(BKTNode) * treeNodeSize;
|
||||
char *mem = (char*)malloc(size);
|
||||
if (mem == NULL) return false;
|
||||
|
||||
auto ptr = mem;
|
||||
*(int*)ptr = m_iTreeNumber;
|
||||
ptr += sizeof(int);
|
||||
|
||||
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
|
||||
ptr += sizeof(int) * m_iTreeNumber;
|
||||
|
||||
*(int*)ptr = treeNodeSize;
|
||||
ptr += sizeof(int);
|
||||
|
||||
memcpy(ptr, m_pTreeRoots.data(), sizeof(BKTNode) * treeNodeSize);
|
||||
*pKDTMemFile = mem;
|
||||
len = size;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SaveTrees(std::string sTreeFileName) const
|
||||
{
|
||||
std::cout << "Save BKT to " << sTreeFileName << std::endl;
|
||||
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
|
||||
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
|
||||
int treeNodeSize = (int)m_pTreeRoots.size();
|
||||
fwrite(&treeNodeSize, sizeof(int), 1, fp);
|
||||
fwrite(m_pTreeRoots.data(), sizeof(BKTNode), treeNodeSize, fp);
|
||||
fclose(fp);
|
||||
std::cout << "Save BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadTrees(char* pBKTMemFile)
|
||||
{
|
||||
m_iTreeNumber = *((int*)pBKTMemFile);
|
||||
pBKTMemFile += sizeof(int);
|
||||
m_pTreeStart.resize(m_iTreeNumber);
|
||||
memcpy(m_pTreeStart.data(), pBKTMemFile, sizeof(int) * m_iTreeNumber);
|
||||
pBKTMemFile += sizeof(int)*m_iTreeNumber;
|
||||
|
||||
int treeNodeSize = *((int*)pBKTMemFile);
|
||||
pBKTMemFile += sizeof(int);
|
||||
m_pTreeRoots.resize(treeNodeSize);
|
||||
memcpy(m_pTreeRoots.data(), pBKTMemFile, sizeof(BKTNode) * treeNodeSize);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadTrees(std::string sTreeFileName)
|
||||
{
|
||||
std::cout << "Load BKT From " << sTreeFileName << std::endl;
|
||||
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fread(&m_iTreeNumber, sizeof(int), 1, fp);
|
||||
m_pTreeStart.resize(m_iTreeNumber);
|
||||
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
|
||||
|
||||
int treeNodeSize;
|
||||
fread(&treeNodeSize, sizeof(int), 1, fp);
|
||||
m_pTreeRoots.resize(treeNodeSize);
|
||||
fread(m_pTreeRoots.data(), sizeof(BKTNode), treeNodeSize, fp);
|
||||
fclose(fp);
|
||||
std::cout << "Load BKT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const
|
||||
{
|
||||
for (char i = 0; i < m_iTreeNumber; i++) {
|
||||
const BKTNode& node = m_pTreeRoots[m_pTreeStart[i]];
|
||||
if (node.childStart < 0) {
|
||||
p_space.m_SPTQueue.insert(COMMON::HeapCell(m_pTreeStart[i], p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(node.centerid))));
|
||||
}
|
||||
else {
|
||||
for (int begin = node.childStart; begin < node.childEnd; begin++) {
|
||||
int index = m_pTreeRoots[begin].centerid;
|
||||
p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index))));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
|
||||
COMMON::WorkSpace &p_space, const int p_limits) const
|
||||
{
|
||||
do
|
||||
{
|
||||
COMMON::HeapCell bcell = p_space.m_SPTQueue.pop();
|
||||
const BKTNode& tnode = m_pTreeRoots[bcell.node];
|
||||
if (tnode.childStart < 0) {
|
||||
if (!p_space.CheckAndSet(tnode.centerid)) {
|
||||
p_space.m_iNumberOfCheckedLeaves++;
|
||||
p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
|
||||
}
|
||||
if (p_space.m_iNumberOfCheckedLeaves >= p_limits) break;
|
||||
}
|
||||
else {
|
||||
if (!p_space.CheckAndSet(tnode.centerid)) {
|
||||
p_space.m_NGQueue.insert(COMMON::HeapCell(tnode.centerid, bcell.distance));
|
||||
}
|
||||
for (int begin = tnode.childStart; begin < tnode.childEnd; begin++) {
|
||||
int index = m_pTreeRoots[begin].centerid;
|
||||
p_space.m_SPTQueue.insert(COMMON::HeapCell(begin, p_index->ComputeDistance((const void*)p_query.GetTarget(), p_index->GetSample(index))));
|
||||
}
|
||||
}
|
||||
} while (!p_space.m_SPTQueue.empty());
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
template <typename T>
|
||||
float KmeansAssign(VectorIndex* p_index,
|
||||
std::vector<int>& indices,
|
||||
const int first, const int last, KmeansArgs<T>& args, const bool updateCenters) const {
|
||||
float currDist = 0;
|
||||
int threads = omp_get_num_threads();
|
||||
float lambda = (updateCenters) ? COMMON::Utils::GetBase<T>() * COMMON::Utils::GetBase<T>() / (100.0f * (last - first)) : 0.0f;
|
||||
int subsize = (last - first - 1) / threads + 1;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int tid = 0; tid < threads; tid++)
|
||||
{
|
||||
int istart = first + tid * subsize;
|
||||
int iend = min(first + (tid + 1) * subsize, last);
|
||||
int *inewCounts = args.newCounts + tid * m_iBKTKmeansK;
|
||||
float *inewCenters = args.newCenters + tid * m_iBKTKmeansK * p_index->GetFeatureDim();
|
||||
int * iclusterIdx = args.clusterIdx + tid * m_iBKTKmeansK;
|
||||
float * iclusterDist = args.clusterDist + tid * m_iBKTKmeansK;
|
||||
float idist = 0;
|
||||
for (int i = istart; i < iend; i++) {
|
||||
int clusterid = 0;
|
||||
float smallestDist = MaxDist;
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
float dist = p_index->ComputeDistance(p_index->GetSample(indices[i]), (const void*)(args.centers + k*p_index->GetFeatureDim())) + lambda*args.counts[k];
|
||||
if (dist > -MaxDist && dist < smallestDist) {
|
||||
clusterid = k; smallestDist = dist;
|
||||
}
|
||||
}
|
||||
args.label[i] = clusterid;
|
||||
inewCounts[clusterid]++;
|
||||
idist += smallestDist;
|
||||
if (updateCenters) {
|
||||
const T* v = (const T*)p_index->GetSample(indices[i]);
|
||||
float* center = inewCenters + clusterid*p_index->GetFeatureDim();
|
||||
for (int j = 0; j < p_index->GetFeatureDim(); j++) center[j] += v[j];
|
||||
if (smallestDist > iclusterDist[clusterid]) {
|
||||
iclusterDist[clusterid] = smallestDist;
|
||||
iclusterIdx[clusterid] = indices[i];
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (smallestDist <= iclusterDist[clusterid]) {
|
||||
iclusterDist[clusterid] = smallestDist;
|
||||
iclusterIdx[clusterid] = indices[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
COMMON::Utils::atomic_float_add(&currDist, idist);
|
||||
}
|
||||
|
||||
for (int i = 1; i < threads; i++) {
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++)
|
||||
args.newCounts[k] += args.newCounts[i*m_iBKTKmeansK + k];
|
||||
}
|
||||
|
||||
if (updateCenters) {
|
||||
for (int i = 1; i < threads; i++) {
|
||||
float* currCenter = args.newCenters + i*m_iBKTKmeansK*p_index->GetFeatureDim();
|
||||
for (int j = 0; j < m_iBKTKmeansK * p_index->GetFeatureDim(); j++) args.newCenters[j] += currCenter[j];
|
||||
}
|
||||
|
||||
int maxcluster = 0;
|
||||
for (int k = 1; k < m_iBKTKmeansK; k++) if (args.newCounts[maxcluster] < args.newCounts[k]) maxcluster = k;
|
||||
|
||||
int maxid = maxcluster;
|
||||
for (int tid = 1; tid < threads; tid++) {
|
||||
if (args.clusterDist[maxid] < args.clusterDist[tid * m_iBKTKmeansK + maxcluster]) maxid = tid * m_iBKTKmeansK + maxcluster;
|
||||
}
|
||||
if (args.clusterIdx[maxid] < 0 || args.clusterIdx[maxid] >= p_index->GetNumSamples())
|
||||
std::cout << "first:" << first << " last:" << last << " maxcluster:" << maxcluster << "(" << args.newCounts[maxcluster] << ") Error maxid:" << maxid << " dist:" << args.clusterDist[maxid] << std::endl;
|
||||
maxid = args.clusterIdx[maxid];
|
||||
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
T* TCenter = args.newTCenters + k * p_index->GetFeatureDim();
|
||||
if (args.newCounts[k] == 0) {
|
||||
//int nextid = Utils::rand_int(last, first);
|
||||
//while (args.label[nextid] != maxcluster) nextid = Utils::rand_int(last, first);
|
||||
int nextid = maxid;
|
||||
std::memcpy(TCenter, p_index->GetSample(nextid), sizeof(T)*p_index->GetFeatureDim());
|
||||
}
|
||||
else {
|
||||
float* currCenters = args.newCenters + k * p_index->GetFeatureDim();
|
||||
for (int j = 0; j < p_index->GetFeatureDim(); j++) currCenters[j] /= args.newCounts[k];
|
||||
|
||||
if (p_index->GetDistCalcMethod() == DistCalcMethod::Cosine) {
|
||||
COMMON::Utils::Normalize(currCenters, p_index->GetFeatureDim(), COMMON::Utils::GetBase<T>());
|
||||
}
|
||||
for (int j = 0; j < p_index->GetFeatureDim(); j++) TCenter[j] = (T)(currCenters[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (int i = 1; i < threads; i++) {
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
if (args.clusterIdx[i*m_iBKTKmeansK + k] != -1 && args.clusterDist[i*m_iBKTKmeansK + k] <= args.clusterDist[k]) {
|
||||
args.clusterDist[k] = args.clusterDist[i*m_iBKTKmeansK + k];
|
||||
args.clusterIdx[k] = args.clusterIdx[i*m_iBKTKmeansK + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return currDist;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int KmeansClustering(VectorIndex* p_index,
|
||||
std::vector<int>& indices, const int first, const int last, KmeansArgs<T>& args) const {
|
||||
int iterLimit = 100;
|
||||
|
||||
int batchEnd = min(first + m_iSamples, last);
|
||||
float currDiff, currDist, minClusterDist = MaxDist;
|
||||
for (int numKmeans = 0; numKmeans < 3; numKmeans++) {
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
int randid = COMMON::Utils::rand_int(last, first);
|
||||
std::memcpy(args.centers + k*p_index->GetFeatureDim(), p_index->GetSample(indices[randid]), sizeof(T)*p_index->GetFeatureDim());
|
||||
}
|
||||
args.ClearCounts();
|
||||
currDist = KmeansAssign(p_index, indices, first, batchEnd, args, false);
|
||||
if (currDist < minClusterDist) {
|
||||
minClusterDist = currDist;
|
||||
memcpy(args.newTCenters, args.centers, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim());
|
||||
memcpy(args.counts, args.newCounts, sizeof(int) * m_iBKTKmeansK);
|
||||
}
|
||||
}
|
||||
|
||||
minClusterDist = MaxDist;
|
||||
int noImprovement = 0;
|
||||
for (int iter = 0; iter < iterLimit; iter++) {
|
||||
std::memcpy(args.centers, args.newTCenters, sizeof(T)*m_iBKTKmeansK*p_index->GetFeatureDim());
|
||||
std::random_shuffle(indices.begin() + first, indices.begin() + last);
|
||||
|
||||
args.ClearCenters();
|
||||
args.ClearCounts();
|
||||
args.ClearDists(-MaxDist);
|
||||
currDist = KmeansAssign(p_index, indices, first, batchEnd, args, true);
|
||||
memcpy(args.counts, args.newCounts, sizeof(int)*m_iBKTKmeansK);
|
||||
|
||||
currDiff = 0;
|
||||
for (int k = 0; k < m_iBKTKmeansK; k++) {
|
||||
currDiff += p_index->ComputeDistance((const void*)(args.centers + k*p_index->GetFeatureDim()), (const void*)(args.newTCenters + k*p_index->GetFeatureDim()));
|
||||
}
|
||||
|
||||
if (currDist < minClusterDist) {
|
||||
noImprovement = 0;
|
||||
minClusterDist = currDist;
|
||||
}
|
||||
else {
|
||||
noImprovement++;
|
||||
}
|
||||
if (currDiff < 1e-3 || noImprovement >= 5) break;
|
||||
}
|
||||
|
||||
args.ClearCounts();
|
||||
args.ClearDists(MaxDist);
|
||||
currDist = KmeansAssign(p_index, indices, first, last, args, false);
|
||||
memcpy(args.counts, args.newCounts, sizeof(int)*m_iBKTKmeansK);
|
||||
|
||||
int numClusters = 0;
|
||||
for (int i = 0; i < m_iBKTKmeansK; i++) if (args.counts[i] > 0) numClusters++;
|
||||
|
||||
if (numClusters <= 1) {
|
||||
//if (last - first > 1) std::cout << "large cluster:" << last - first << " dist:" << currDist << std::endl;
|
||||
return numClusters;
|
||||
}
|
||||
args.Shuffle(indices, first, last);
|
||||
return numClusters;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> m_pTreeStart;
|
||||
std::vector<BKTNode> m_pTreeRoots;
|
||||
std::unordered_map<int, int> m_pSampleCenterMap;
|
||||
|
||||
public:
|
||||
int m_iTreeNumber, m_iBKTKmeansK, m_iBKTLeafSize, m_iSamples;
|
||||
};
|
||||
}
|
||||
}
|
||||
#endif
|
@ -1,178 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_COMMONUTILS_H_
|
||||
#define _SPTAG_COMMON_COMMONUTILS_H_
|
||||
|
||||
#include "../Common.h"
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <exception>
|
||||
#include <algorithm>
|
||||
|
||||
#include <time.h>
|
||||
#include <omp.h>
|
||||
#include <string.h>
|
||||
|
||||
#define PREFETCH
|
||||
|
||||
#ifndef _MSC_VER
|
||||
#include <stdio.h>
|
||||
#include <unistd.h>
|
||||
#include <sys/resource.h>
|
||||
#include <cstring>
|
||||
|
||||
#define InterlockedCompareExchange(a,b,c) __sync_val_compare_and_swap(a, c, b)
|
||||
#define Sleep(a) usleep(a * 1000)
|
||||
#define strtok_s(a, b, c) strtok_r(a, b, c)
|
||||
#endif
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
class Utils {
|
||||
public:
|
||||
static int rand_int(int high = RAND_MAX, int low = 0) // Generates a random int value.
|
||||
{
|
||||
return low + (int)(float(high - low)*(std::rand() / (RAND_MAX + 1.0)));
|
||||
}
|
||||
|
||||
static inline float atomic_float_add(volatile float* ptr, const float operand)
|
||||
{
|
||||
union {
|
||||
volatile long iOld;
|
||||
float fOld;
|
||||
};
|
||||
union {
|
||||
long iNew;
|
||||
float fNew;
|
||||
};
|
||||
|
||||
while (true) {
|
||||
iOld = *(volatile long *)ptr;
|
||||
fNew = fOld + operand;
|
||||
if (InterlockedCompareExchange((long *)ptr, iNew, iOld) == iOld) {
|
||||
return fNew;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static double GetVector(char* cstr, const char* sep, std::vector<float>& arr, int& NumDim) {
|
||||
char* current;
|
||||
char* context = NULL;
|
||||
|
||||
int i = 0;
|
||||
double sum = 0;
|
||||
arr.clear();
|
||||
current = strtok_s(cstr, sep, &context);
|
||||
while (current != NULL && (i < NumDim || NumDim < 0)) {
|
||||
try {
|
||||
float val = (float)atof(current);
|
||||
arr.push_back(val);
|
||||
}
|
||||
catch (std::exception e) {
|
||||
std::cout << "Exception:" << e.what() << std::endl;
|
||||
return -2;
|
||||
}
|
||||
|
||||
sum += arr[i] * arr[i];
|
||||
current = strtok_s(NULL, sep, &context);
|
||||
i++;
|
||||
}
|
||||
|
||||
if (NumDim < 0) NumDim = i;
|
||||
if (i < NumDim) return -2;
|
||||
return std::sqrt(sum);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void Normalize(T* arr, int col, int base) {
|
||||
double vecLen = 0;
|
||||
for (int j = 0; j < col; j++) {
|
||||
double val = arr[j];
|
||||
vecLen += val * val;
|
||||
}
|
||||
vecLen = std::sqrt(vecLen);
|
||||
if (vecLen < 1e-6) {
|
||||
T val = (T)(1.0 / std::sqrt((double)col) * base);
|
||||
for (int j = 0; j < col; j++) arr[j] = val;
|
||||
}
|
||||
else {
|
||||
for (int j = 0; j < col; j++) arr[j] = (T)(arr[j] / vecLen * base);
|
||||
}
|
||||
}
|
||||
|
||||
static size_t ProcessLine(std::string& currentLine, std::vector<float>& arr, int& D, int base, DistCalcMethod distCalcMethod) {
|
||||
size_t index;
|
||||
double vecLen;
|
||||
if (currentLine.length() == 0 || (index = currentLine.find_last_of("\t")) == std::string::npos || (vecLen = GetVector(const_cast<char*>(currentLine.c_str() + index + 1), "|", arr, D)) < -1) {
|
||||
std::cout << "Parse vector error: " + currentLine << std::endl;
|
||||
//throw MyException("Error in parsing data " + currentLine);
|
||||
return -1;
|
||||
}
|
||||
if (distCalcMethod == DistCalcMethod::Cosine) {
|
||||
Normalize(arr.data(), D, base);
|
||||
}
|
||||
return index;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void PrepareQuerys(std::ifstream& inStream, std::vector<std::string>& qString, std::vector<std::vector<T>>& Query, int& NumQuery, int& NumDim, DistCalcMethod distCalcMethod, int base) {
|
||||
std::string currentLine;
|
||||
std::vector<float> arr;
|
||||
int i = 0;
|
||||
size_t index;
|
||||
while ((NumQuery < 0 || i < NumQuery) && !inStream.eof()) {
|
||||
std::getline(inStream, currentLine);
|
||||
if (currentLine.length() <= 1 || (index = ProcessLine(currentLine, arr, NumDim, base, distCalcMethod)) < 0) {
|
||||
continue;
|
||||
}
|
||||
qString.push_back(currentLine.substr(0, index));
|
||||
if (Query.size() < i + 1) Query.push_back(std::vector<T>(NumDim, 0));
|
||||
|
||||
for (int j = 0; j < NumDim; j++) Query[i][j] = (T)arr[j];
|
||||
i++;
|
||||
}
|
||||
NumQuery = i;
|
||||
std::cout << "Load data: (" << NumQuery << ", " << NumDim << ")" << std::endl;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static inline int GetBase() {
|
||||
if (GetEnumValueType<T>() != VectorValueType::Float) {
|
||||
return (int)(std::numeric_limits<T>::max)();
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
static inline void AddNeighbor(int idx, float dist, int *neighbors, float *dists, int size)
|
||||
{
|
||||
size--;
|
||||
if (dist < dists[size] || (dist == dists[size] && idx < neighbors[size]))
|
||||
{
|
||||
int nb;
|
||||
for (nb = 0; nb <= size && neighbors[nb] != idx; nb++);
|
||||
|
||||
if (nb > size)
|
||||
{
|
||||
nb = size;
|
||||
while (nb > 0 && (dist < dists[nb - 1] || (dist == dists[nb - 1] && idx < neighbors[nb - 1])))
|
||||
{
|
||||
dists[nb] = dists[nb - 1];
|
||||
neighbors[nb] = neighbors[nb - 1];
|
||||
nb--;
|
||||
}
|
||||
dists[nb] = dist;
|
||||
neighbors[nb] = idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_COMMONUTILS_H_
|
@ -1,290 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_DATAUTILS_H_
|
||||
#define _SPTAG_COMMON_DATAUTILS_H_
|
||||
|
||||
#include <sys/stat.h>
|
||||
#include <atomic>
|
||||
#include "CommonUtils.h"
|
||||
#include "../../Helper/CommonHelper.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
const int bufsize = 1024 * 1024 * 1024;
|
||||
|
||||
class DataUtils {
|
||||
public:
|
||||
template <typename T>
|
||||
static void ProcessTSVData(int id, int threadbase, std::uint64_t blocksize,
|
||||
std::string filename, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
|
||||
std::atomic_int& numSamples, int& D, DistCalcMethod distCalcMethod) {
|
||||
std::ifstream inputStream(filename);
|
||||
if (!inputStream.is_open()) {
|
||||
std::cerr << "unable to open file " + filename << std::endl;
|
||||
throw MyException("unable to open file " + filename);
|
||||
exit(1);
|
||||
}
|
||||
std::ofstream outputStream, metaStream_out, metaStream_index;
|
||||
outputStream.open(outfile + std::to_string(id + threadbase), std::ofstream::binary);
|
||||
metaStream_out.open(outmetafile + std::to_string(id + threadbase), std::ofstream::binary);
|
||||
metaStream_index.open(outmetaindexfile + std::to_string(id + threadbase), std::ofstream::binary);
|
||||
if (!outputStream.is_open() || !metaStream_out.is_open() || !metaStream_index.is_open()) {
|
||||
std::cerr << "unable to open output file " << outfile << " " << outmetafile << " " << outmetaindexfile << std::endl;
|
||||
throw MyException("unable to open output files");
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::vector<float> arr;
|
||||
std::vector<T> sample;
|
||||
|
||||
int base = 1;
|
||||
if (distCalcMethod == DistCalcMethod::Cosine) {
|
||||
base = Utils::GetBase<T>();
|
||||
}
|
||||
std::uint64_t writepos = 0;
|
||||
int sampleSize = 0;
|
||||
std::uint64_t totalread = 0;
|
||||
std::streamoff startpos = id * blocksize;
|
||||
|
||||
#ifndef _MSC_VER
|
||||
int enter_size = 1;
|
||||
#else
|
||||
int enter_size = 1;
|
||||
#endif
|
||||
std::string currentLine;
|
||||
size_t index;
|
||||
inputStream.seekg(startpos, std::ifstream::beg);
|
||||
if (id != 0) {
|
||||
std::getline(inputStream, currentLine);
|
||||
totalread += currentLine.length() + enter_size;
|
||||
}
|
||||
std::cout << "Begin thread " << id << " begin at:" << (startpos + totalread) << std::endl;
|
||||
while (!inputStream.eof() && totalread <= blocksize) {
|
||||
std::getline(inputStream, currentLine);
|
||||
if (currentLine.length() <= enter_size || (index = Utils::ProcessLine(currentLine, arr, D, base, distCalcMethod)) < 0) {
|
||||
totalread += currentLine.length() + enter_size;
|
||||
continue;
|
||||
}
|
||||
sample.resize(D);
|
||||
for (int j = 0; j < D; j++) sample[j] = (T)arr[j];
|
||||
|
||||
outputStream.write((char *)(sample.data()), sizeof(T)*D);
|
||||
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
|
||||
metaStream_out.write(currentLine.c_str(), index);
|
||||
|
||||
writepos += index;
|
||||
sampleSize += 1;
|
||||
totalread += currentLine.length() + enter_size;
|
||||
}
|
||||
metaStream_index.write((char *)&writepos, sizeof(std::uint64_t));
|
||||
metaStream_index.write((char *)&sampleSize, sizeof(int));
|
||||
inputStream.close();
|
||||
outputStream.close();
|
||||
metaStream_out.close();
|
||||
metaStream_index.close();
|
||||
|
||||
numSamples.fetch_add(sampleSize);
|
||||
|
||||
std::cout << "Finish Thread[" << id << ", " << sampleSize << "] at:" << (startpos + totalread) << std::endl;
|
||||
}
|
||||
|
||||
static void MergeData(int threadbase, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
|
||||
std::atomic_int& numSamples, int D) {
|
||||
std::ifstream inputStream;
|
||||
std::ofstream outputStream;
|
||||
char * buf = new char[bufsize];
|
||||
std::uint64_t * offsets;
|
||||
int partSamples;
|
||||
int metaSamples = 0;
|
||||
std::uint64_t lastoff = 0;
|
||||
|
||||
outputStream.open(outfile, std::ofstream::binary);
|
||||
outputStream.write((char *)&numSamples, sizeof(int));
|
||||
outputStream.write((char *)&D, sizeof(int));
|
||||
for (int i = 0; i < threadbase; i++) {
|
||||
std::string file = outfile + std::to_string(i);
|
||||
inputStream.open(file, std::ifstream::binary);
|
||||
while (!inputStream.eof()) {
|
||||
inputStream.read(buf, bufsize);
|
||||
outputStream.write(buf, inputStream.gcount());
|
||||
}
|
||||
inputStream.close();
|
||||
remove(file.c_str());
|
||||
}
|
||||
outputStream.close();
|
||||
|
||||
outputStream.open(outmetafile, std::ofstream::binary);
|
||||
for (int i = 0; i < threadbase; i++) {
|
||||
std::string file = outmetafile + std::to_string(i);
|
||||
inputStream.open(file, std::ifstream::binary);
|
||||
while (!inputStream.eof()) {
|
||||
inputStream.read(buf, bufsize);
|
||||
outputStream.write(buf, inputStream.gcount());
|
||||
}
|
||||
inputStream.close();
|
||||
remove(file.c_str());
|
||||
}
|
||||
outputStream.close();
|
||||
delete[] buf;
|
||||
|
||||
outputStream.open(outmetaindexfile, std::ofstream::binary);
|
||||
outputStream.write((char *)&numSamples, sizeof(int));
|
||||
for (int i = 0; i < threadbase; i++) {
|
||||
std::string file = outmetaindexfile + std::to_string(i);
|
||||
inputStream.open(file, std::ifstream::binary);
|
||||
|
||||
inputStream.seekg(-((long long)sizeof(int)), inputStream.end);
|
||||
inputStream.read((char *)&partSamples, sizeof(int));
|
||||
offsets = new std::uint64_t[partSamples + 1];
|
||||
|
||||
inputStream.seekg(0, inputStream.beg);
|
||||
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1));
|
||||
inputStream.close();
|
||||
remove(file.c_str());
|
||||
|
||||
for (int j = 0; j < partSamples + 1; j++)
|
||||
offsets[j] += lastoff;
|
||||
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples);
|
||||
|
||||
lastoff = offsets[partSamples];
|
||||
metaSamples += partSamples;
|
||||
delete[] offsets;
|
||||
}
|
||||
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
|
||||
outputStream.close();
|
||||
|
||||
std::cout << "numSamples:" << numSamples << " metaSamples:" << metaSamples << " D:" << D << std::endl;
|
||||
}
|
||||
|
||||
static bool MergeIndex(const std::string& p_vectorfile1, const std::string& p_metafile1, const std::string& p_metaindexfile1,
|
||||
const std::string& p_vectorfile2, const std::string& p_metafile2, const std::string& p_metaindexfile2) {
|
||||
std::ifstream inputStream1, inputStream2;
|
||||
std::ofstream outputStream;
|
||||
char * buf = new char[bufsize];
|
||||
int R1, R2, C1, C2;
|
||||
|
||||
#define MergeVector(inputStream, vectorFile, R, C) \
|
||||
inputStream.open(vectorFile, std::ifstream::binary); \
|
||||
if (!inputStream.is_open()) { \
|
||||
std::cout << "Cannot open vector file: " << vectorFile <<"!" << std::endl; \
|
||||
return false; \
|
||||
} \
|
||||
inputStream.read((char *)&(R), sizeof(int)); \
|
||||
inputStream.read((char *)&(C), sizeof(int)); \
|
||||
|
||||
MergeVector(inputStream1, p_vectorfile1, R1, C1)
|
||||
MergeVector(inputStream2, p_vectorfile2, R2, C2)
|
||||
#undef MergeVector
|
||||
if (C1 != C2) {
|
||||
inputStream1.close(); inputStream2.close();
|
||||
std::cout << "Vector dimensions are not the same!" << std::endl;
|
||||
return false;
|
||||
}
|
||||
R1 += R2;
|
||||
outputStream.open(p_vectorfile1 + "_tmp", std::ofstream::binary);
|
||||
outputStream.write((char *)&R1, sizeof(int));
|
||||
outputStream.write((char *)&C1, sizeof(int));
|
||||
while (!inputStream1.eof()) {
|
||||
inputStream1.read(buf, bufsize);
|
||||
outputStream.write(buf, inputStream1.gcount());
|
||||
}
|
||||
while (!inputStream2.eof()) {
|
||||
inputStream2.read(buf, bufsize);
|
||||
outputStream.write(buf, inputStream2.gcount());
|
||||
}
|
||||
inputStream1.close(); inputStream2.close();
|
||||
outputStream.close();
|
||||
|
||||
if (p_metafile1 != "" && p_metafile2 != "") {
|
||||
outputStream.open(p_metafile1 + "_tmp", std::ofstream::binary);
|
||||
#define MergeMeta(inputStream, metaFile) \
|
||||
inputStream.open(metaFile, std::ifstream::binary); \
|
||||
if (!inputStream.is_open()) { \
|
||||
std::cout << "Cannot open meta file: " << metaFile << "!" << std::endl; \
|
||||
return false; \
|
||||
} \
|
||||
while (!inputStream.eof()) { \
|
||||
inputStream.read(buf, bufsize); \
|
||||
outputStream.write(buf, inputStream.gcount()); \
|
||||
} \
|
||||
inputStream.close(); \
|
||||
|
||||
MergeMeta(inputStream1, p_metafile1)
|
||||
MergeMeta(inputStream2, p_metafile2)
|
||||
#undef MergeMeta
|
||||
outputStream.close();
|
||||
delete[] buf;
|
||||
|
||||
|
||||
std::uint64_t * offsets;
|
||||
int partSamples;
|
||||
std::uint64_t lastoff = 0;
|
||||
outputStream.open(p_metaindexfile1 + "_tmp", std::ofstream::binary);
|
||||
outputStream.write((char *)&R1, sizeof(int));
|
||||
#define MergeMetaIndex(inputStream, metaIndexFile) \
|
||||
inputStream.open(metaIndexFile, std::ifstream::binary); \
|
||||
if (!inputStream.is_open()) { \
|
||||
std::cout << "Cannot open meta index file: " << metaIndexFile << "!" << std::endl; \
|
||||
return false; \
|
||||
} \
|
||||
inputStream.read((char *)&partSamples, sizeof(int)); \
|
||||
offsets = new std::uint64_t[partSamples + 1]; \
|
||||
inputStream.read((char *)offsets, sizeof(std::uint64_t)*(partSamples + 1)); \
|
||||
inputStream.close(); \
|
||||
for (int j = 0; j < partSamples + 1; j++) offsets[j] += lastoff; \
|
||||
outputStream.write((char *)offsets, sizeof(std::uint64_t)*partSamples); \
|
||||
lastoff = offsets[partSamples]; \
|
||||
delete[] offsets; \
|
||||
|
||||
MergeMetaIndex(inputStream1, p_metaindexfile1)
|
||||
MergeMetaIndex(inputStream2, p_metaindexfile2)
|
||||
#undef MergeMetaIndex
|
||||
outputStream.write((char *)&lastoff, sizeof(std::uint64_t));
|
||||
outputStream.close();
|
||||
|
||||
rename((p_metafile1 + "_tmp").c_str(), p_metafile1.c_str());
|
||||
rename((p_metaindexfile1 + "_tmp").c_str(), p_metaindexfile1.c_str());
|
||||
}
|
||||
rename((p_vectorfile1 + "_tmp").c_str(), p_vectorfile1.c_str());
|
||||
|
||||
std::cout << "Merged -> numSamples:" << R1 << " D:" << C1 << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void ParseData(std::string filenames, std::string outfile, std::string outmetafile, std::string outmetaindexfile,
|
||||
int threadnum, DistCalcMethod distCalcMethod) {
|
||||
omp_set_num_threads(threadnum);
|
||||
|
||||
std::atomic_int numSamples = { 0 };
|
||||
int D = -1;
|
||||
|
||||
int threadbase = 0;
|
||||
std::vector<std::string> inputFileNames = Helper::StrUtils::SplitString(filenames, ",");
|
||||
for (std::string inputFileName : inputFileNames)
|
||||
{
|
||||
#ifndef _MSC_VER
|
||||
struct stat stat_buf;
|
||||
stat(inputFileName.c_str(), &stat_buf);
|
||||
#else
|
||||
struct _stat64 stat_buf;
|
||||
int res = _stat64(inputFileName.c_str(), &stat_buf);
|
||||
#endif
|
||||
std::uint64_t blocksize = (stat_buf.st_size + threadnum - 1) / threadnum;
|
||||
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < threadnum; i++) {
|
||||
ProcessTSVData<T>(i, threadbase, blocksize, inputFileName, outfile, outmetafile, outmetaindexfile, numSamples, D, distCalcMethod);
|
||||
}
|
||||
threadbase += threadnum;
|
||||
}
|
||||
MergeData(threadbase, outfile, outmetafile, outmetaindexfile, numSamples, D);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_DATAUTILS_H_
|
@ -1,216 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_DATASET_H_
|
||||
#define _SPTAG_COMMON_DATASET_H_
|
||||
|
||||
#include <fstream>
|
||||
|
||||
#if defined(_MSC_VER) || defined(__INTEL_COMPILER)
|
||||
#include <malloc.h>
|
||||
#else
|
||||
#include <mm_malloc.h>
|
||||
#endif // defined(__GNUC__)
|
||||
|
||||
#define ALIGN 32
|
||||
|
||||
#define aligned_malloc(a, b) _mm_malloc(a, b)
|
||||
#define aligned_free(a) _mm_free(a)
|
||||
|
||||
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
// structure to save Data and Graph
|
||||
template <typename T>
|
||||
class Dataset
|
||||
{
|
||||
private:
|
||||
int rows;
|
||||
int cols;
|
||||
bool ownData = false;
|
||||
T* data = nullptr;
|
||||
std::vector<T> dataIncremental;
|
||||
|
||||
public:
|
||||
Dataset(): rows(0), cols(1) {}
|
||||
Dataset(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
|
||||
{
|
||||
Initialize(rows_, cols_, data_, transferOnwership_);
|
||||
}
|
||||
~Dataset()
|
||||
{
|
||||
if (ownData) aligned_free(data);
|
||||
}
|
||||
void Initialize(int rows_, int cols_, T* data_ = nullptr, bool transferOnwership_ = true)
|
||||
{
|
||||
rows = rows_;
|
||||
cols = cols_;
|
||||
data = data_;
|
||||
if (data_ == nullptr || !transferOnwership_)
|
||||
{
|
||||
ownData = true;
|
||||
data = (T*)aligned_malloc(sizeof(T) * rows * cols, ALIGN);
|
||||
if (data_ != nullptr) memcpy(data, data_, rows * cols * sizeof(T));
|
||||
else std::memset(data, -1, rows * cols * sizeof(T));
|
||||
}
|
||||
}
|
||||
void SetR(int R_)
|
||||
{
|
||||
if (R_ >= rows)
|
||||
dataIncremental.resize((R_ - rows) * cols);
|
||||
else
|
||||
{
|
||||
rows = R_;
|
||||
dataIncremental.clear();
|
||||
}
|
||||
}
|
||||
inline int R() const { return (int)(rows + dataIncremental.size() / cols); }
|
||||
inline int C() const { return cols; }
|
||||
T* operator[](int index)
|
||||
{
|
||||
if (index >= rows) {
|
||||
return dataIncremental.data() + (size_t)(index - rows)*cols;
|
||||
}
|
||||
return data + (size_t)index*cols;
|
||||
}
|
||||
|
||||
const T* operator[](int index) const
|
||||
{
|
||||
if (index >= rows) {
|
||||
return dataIncremental.data() + (size_t)(index - rows)*cols;
|
||||
}
|
||||
return data + (size_t)index*cols;
|
||||
}
|
||||
|
||||
void AddBatch(const T* pData, int num)
|
||||
{
|
||||
dataIncremental.insert(dataIncremental.end(), pData, pData + num*cols);
|
||||
}
|
||||
|
||||
void AddBatch(int num)
|
||||
{
|
||||
dataIncremental.insert(dataIncremental.end(), (size_t)num*cols, T(-1));
|
||||
}
|
||||
|
||||
bool Save(std::string sDataPointsFileName)
|
||||
{
|
||||
std::cout << "Save Data To " << sDataPointsFileName << std::endl;
|
||||
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
int CR = R();
|
||||
fwrite(&CR, sizeof(int), 1, fp);
|
||||
fwrite(&cols, sizeof(int), 1, fp);
|
||||
|
||||
T* ptr = data;
|
||||
int toWrite = rows;
|
||||
while (toWrite > 0)
|
||||
{
|
||||
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
|
||||
ptr += write * cols;
|
||||
toWrite -= (int)write;
|
||||
}
|
||||
ptr = dataIncremental.data();
|
||||
toWrite = CR - rows;
|
||||
while (toWrite > 0)
|
||||
{
|
||||
size_t write = fwrite(ptr, sizeof(T) * cols, toWrite, fp);
|
||||
ptr += write * cols;
|
||||
toWrite -= (int)write;
|
||||
}
|
||||
fclose(fp);
|
||||
|
||||
std::cout << "Save Data (" << CR << ", " << cols << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Save(void **pDataPointsMemFile, int64_t &len)
|
||||
{
|
||||
size_t size = sizeof(int) + sizeof(int) + sizeof(T) * R() *cols;
|
||||
char *mem = (char*)malloc(size);
|
||||
if (mem == NULL) return false;
|
||||
|
||||
int CR = R();
|
||||
|
||||
auto header = (int*)mem;
|
||||
header[0] = CR;
|
||||
header[1] = cols;
|
||||
auto body = &mem[8];
|
||||
|
||||
memcpy(body, data, sizeof(T) * cols * rows);
|
||||
body += sizeof(T) * cols * rows;
|
||||
memcpy(body, dataIncremental.data(), sizeof(T) * cols * (CR - rows));
|
||||
body += sizeof(T) * cols * (CR - rows);
|
||||
|
||||
*pDataPointsMemFile = mem;
|
||||
len = size;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Load(std::string sDataPointsFileName)
|
||||
{
|
||||
std::cout << "Load Data From " << sDataPointsFileName << std::endl;
|
||||
FILE * fp = fopen(sDataPointsFileName.c_str(), "rb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
int R, C;
|
||||
fread(&R, sizeof(int), 1, fp);
|
||||
fread(&C, sizeof(int), 1, fp);
|
||||
|
||||
Initialize(R, C);
|
||||
T* ptr = data;
|
||||
while (R > 0) {
|
||||
size_t read = fread(ptr, sizeof(T) * C, R, fp);
|
||||
ptr += read * C;
|
||||
R -= (int)read;
|
||||
}
|
||||
fclose(fp);
|
||||
std::cout << "Load Data (" << rows << ", " << cols << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
// Functions for loading models from memory mapped files
|
||||
bool Load(char* pDataPointsMemFile)
|
||||
{
|
||||
int R, C;
|
||||
R = *((int*)pDataPointsMemFile);
|
||||
pDataPointsMemFile += sizeof(int);
|
||||
|
||||
C = *((int*)pDataPointsMemFile);
|
||||
pDataPointsMemFile += sizeof(int);
|
||||
|
||||
Initialize(R, C, (T*)pDataPointsMemFile);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Refine(const std::vector<int>& indices, std::string sDataPointsFileName)
|
||||
{
|
||||
std::cout << "Save Refine Data To " << sDataPointsFileName << std::endl;
|
||||
FILE * fp = fopen(sDataPointsFileName.c_str(), "wb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
int R = (int)(indices.size());
|
||||
fwrite(&R, sizeof(int), 1, fp);
|
||||
fwrite(&cols, sizeof(int), 1, fp);
|
||||
|
||||
// write point one by one in case for cache miss
|
||||
for (int i = 0; i < R; i++) {
|
||||
if (indices[i] < rows)
|
||||
fwrite(data + (size_t)indices[i] * cols, sizeof(T) * cols, 1, fp);
|
||||
else
|
||||
fwrite(dataIncremental.data() + (size_t)(indices[i] - rows) * cols, sizeof(T) * cols, 1, fp);
|
||||
}
|
||||
fclose(fp);
|
||||
|
||||
std::cout << "Save Refine Data (" << R << ", " << cols << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_DATASET_H_
|
@ -1,610 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_DISTANCEUTILS_H_
|
||||
#define _SPTAG_COMMON_DISTANCEUTILS_H_
|
||||
|
||||
#include <immintrin.h>
|
||||
#include <functional>
|
||||
|
||||
#include "CommonUtils.h"
|
||||
|
||||
#define SSE
|
||||
|
||||
#ifndef _MSC_VER
|
||||
#define DIFF128 diff128
|
||||
#define DIFF256 diff256
|
||||
#else
|
||||
#define DIFF128 diff128.m128_f32
|
||||
#define DIFF256 diff256.m256_f32
|
||||
#endif
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
class DistanceUtils
|
||||
{
|
||||
public:
|
||||
static inline __m128 _mm_mul_epi8(__m128i X, __m128i Y)
|
||||
{
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
|
||||
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_sqdf_epi8(__m128i X, __m128i Y)
|
||||
{
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i sign_x = _mm_cmplt_epi8(X, zero);
|
||||
__m128i sign_y = _mm_cmplt_epi8(Y, zero);
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, sign_x);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, sign_x);
|
||||
__m128i ylo = _mm_unpacklo_epi8(Y, sign_y);
|
||||
__m128i yhi = _mm_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
__m128i dlo = _mm_sub_epi16(xlo, ylo);
|
||||
__m128i dhi = _mm_sub_epi16(xhi, yhi);
|
||||
|
||||
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_mul_epu8(__m128i X, __m128i Y)
|
||||
{
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, zero);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, zero);
|
||||
__m128i ylo = _mm_unpacklo_epi8(Y, zero);
|
||||
__m128i yhi = _mm_unpackhi_epi8(Y, zero);
|
||||
|
||||
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(xlo, ylo), _mm_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_sqdf_epu8(__m128i X, __m128i Y)
|
||||
{
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi8(X, zero);
|
||||
__m128i xhi = _mm_unpackhi_epi8(X, zero);
|
||||
__m128i ylo = _mm_unpacklo_epi8(Y, zero);
|
||||
__m128i yhi = _mm_unpackhi_epi8(Y, zero);
|
||||
|
||||
__m128i dlo = _mm_sub_epi16(xlo, ylo);
|
||||
__m128i dhi = _mm_sub_epi16(xhi, yhi);
|
||||
|
||||
return _mm_cvtepi32_ps(_mm_add_epi32(_mm_madd_epi16(dlo, dlo), _mm_madd_epi16(dhi, dhi)));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_mul_epi16(__m128i X, __m128i Y)
|
||||
{
|
||||
return _mm_cvtepi32_ps(_mm_madd_epi16(X, Y));
|
||||
}
|
||||
|
||||
static inline __m128 _mm_sqdf_epi16(__m128i X, __m128i Y)
|
||||
{
|
||||
__m128i zero = _mm_setzero_si128();
|
||||
|
||||
__m128i sign_x = _mm_cmplt_epi16(X, zero);
|
||||
__m128i sign_y = _mm_cmplt_epi16(Y, zero);
|
||||
|
||||
__m128i xlo = _mm_unpacklo_epi16(X, sign_x);
|
||||
__m128i xhi = _mm_unpackhi_epi16(X, sign_x);
|
||||
__m128i ylo = _mm_unpacklo_epi16(Y, sign_y);
|
||||
__m128i yhi = _mm_unpackhi_epi16(Y, sign_y);
|
||||
|
||||
__m128 dlo = _mm_cvtepi32_ps(_mm_sub_epi32(xlo, ylo));
|
||||
__m128 dhi = _mm_cvtepi32_ps(_mm_sub_epi32(xhi, yhi));
|
||||
|
||||
return _mm_add_ps(_mm_mul_ps(dlo, dlo), _mm_mul_ps(dhi, dhi));
|
||||
}
|
||||
static inline __m128 _mm_sqdf_ps(__m128 X, __m128 Y)
|
||||
{
|
||||
__m128 d = _mm_sub_ps(X, Y);
|
||||
return _mm_mul_ps(d, d);
|
||||
}
|
||||
#if defined(AVX)
|
||||
static inline __m256 _mm256_mul_epi8(__m256i X, __m256i Y)
|
||||
{
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
|
||||
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
|
||||
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
|
||||
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
|
||||
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
static inline __m256 _mm256_sqdf_epi8(__m256i X, __m256i Y)
|
||||
{
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i sign_x = _mm256_cmpgt_epi8(zero, X);
|
||||
__m256i sign_y = _mm256_cmpgt_epi8(zero, Y);
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi8(X, sign_x);
|
||||
__m256i xhi = _mm256_unpackhi_epi8(X, sign_x);
|
||||
__m256i ylo = _mm256_unpacklo_epi8(Y, sign_y);
|
||||
__m256i yhi = _mm256_unpackhi_epi8(Y, sign_y);
|
||||
|
||||
__m256i dlo = _mm256_sub_epi16(xlo, ylo);
|
||||
__m256i dhi = _mm256_sub_epi16(xhi, yhi);
|
||||
|
||||
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi)));
|
||||
}
|
||||
static inline __m256 _mm256_mul_epu8(__m256i X, __m256i Y)
|
||||
{
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi8(X, zero);
|
||||
__m256i xhi = _mm256_unpackhi_epi8(X, zero);
|
||||
__m256i ylo = _mm256_unpacklo_epi8(Y, zero);
|
||||
__m256i yhi = _mm256_unpackhi_epi8(Y, zero);
|
||||
|
||||
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(xlo, ylo), _mm256_madd_epi16(xhi, yhi)));
|
||||
}
|
||||
static inline __m256 _mm256_sqdf_epu8(__m256i X, __m256i Y)
|
||||
{
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi8(X, zero);
|
||||
__m256i xhi = _mm256_unpackhi_epi8(X, zero);
|
||||
__m256i ylo = _mm256_unpacklo_epi8(Y, zero);
|
||||
__m256i yhi = _mm256_unpackhi_epi8(Y, zero);
|
||||
|
||||
__m256i dlo = _mm256_sub_epi16(xlo, ylo);
|
||||
__m256i dhi = _mm256_sub_epi16(xhi, yhi);
|
||||
|
||||
return _mm256_cvtepi32_ps(_mm256_add_epi32(_mm256_madd_epi16(dlo, dlo), _mm256_madd_epi16(dhi, dhi)));
|
||||
}
|
||||
static inline __m256 _mm256_mul_epi16(__m256i X, __m256i Y)
|
||||
{
|
||||
return _mm256_cvtepi32_ps(_mm256_madd_epi16(X, Y));
|
||||
}
|
||||
static inline __m256 _mm256_sqdf_epi16(__m256i X, __m256i Y)
|
||||
{
|
||||
__m256i zero = _mm256_setzero_si256();
|
||||
|
||||
__m256i sign_x = _mm256_cmpgt_epi16(zero, X);
|
||||
__m256i sign_y = _mm256_cmpgt_epi16(zero, Y);
|
||||
|
||||
__m256i xlo = _mm256_unpacklo_epi16(X, sign_x);
|
||||
__m256i xhi = _mm256_unpackhi_epi16(X, sign_x);
|
||||
__m256i ylo = _mm256_unpacklo_epi16(Y, sign_y);
|
||||
__m256i yhi = _mm256_unpackhi_epi16(Y, sign_y);
|
||||
|
||||
__m256 dlo = _mm256_cvtepi32_ps(_mm256_sub_epi32(xlo, ylo));
|
||||
__m256 dhi = _mm256_cvtepi32_ps(_mm256_sub_epi32(xhi, yhi));
|
||||
|
||||
return _mm256_add_ps(_mm256_mul_ps(dlo, dlo), _mm256_mul_ps(dhi, dhi));
|
||||
}
|
||||
static inline __m256 _mm256_sqdf_ps(__m256 X, __m256 Y)
|
||||
{
|
||||
__m256 d = _mm256_sub_ps(X, Y);
|
||||
return _mm256_mul_ps(d, d);
|
||||
}
|
||||
#endif
|
||||
/*
|
||||
template<typename T>
|
||||
static float ComputeL2Distance(const T *pX, const T *pY, int length)
|
||||
{
|
||||
float diff = 0;
|
||||
const T* pEnd1 = pX + length;
|
||||
while (pX < pEnd1) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
*/
|
||||
#define REPEAT(type, ctype, delta, load, exec, acc, result) \
|
||||
{ \
|
||||
type c1 = load((ctype *)(pX)); \
|
||||
type c2 = load((ctype *)(pY)); \
|
||||
pX += delta; pY += delta; \
|
||||
result = acc(result, exec(c1, c2)); \
|
||||
} \
|
||||
|
||||
static float ComputeL2Distance(const std::int8_t *pX, const std::int8_t *pY, int length)
|
||||
{
|
||||
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
|
||||
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::int8_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::int8_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epi8, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
while (pX < pEnd1) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
|
||||
static float ComputeL2Distance(const std::uint8_t *pX, const std::uint8_t *pY, int length)
|
||||
{
|
||||
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
|
||||
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::uint8_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_sqdf_epu8, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_sqdf_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
while (pX < pEnd1) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
|
||||
static float ComputeL2Distance(const std::int16_t *pX, const std::int16_t *pY, int length)
|
||||
{
|
||||
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
|
||||
const std::int16_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::int16_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd8) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_sqdf_epi16, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd8) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_sqdf_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
|
||||
while (pX < pEnd1) {
|
||||
float c1 = ((float)(*pX++) - (float)(*pY++)); diff += c1 * c1;
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
|
||||
static float ComputeL2Distance(const float *pX, const float *pY, int length)
|
||||
{
|
||||
const float* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const float* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const float* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd16)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd16)
|
||||
{
|
||||
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256)
|
||||
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_sqdf_ps, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_sqdf_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
while (pX < pEnd4) {
|
||||
float c1 = (*pX++) - (*pY++); diff += c1 * c1;
|
||||
c1 = (*pX++) - (*pY++); diff += c1 * c1;
|
||||
c1 = (*pX++) - (*pY++); diff += c1 * c1;
|
||||
c1 = (*pX++) - (*pY++); diff += c1 * c1;
|
||||
}
|
||||
#endif
|
||||
while (pX < pEnd1) {
|
||||
float c1 = (*pX++) - (*pY++); diff += c1 * c1;
|
||||
}
|
||||
return diff;
|
||||
}
|
||||
/*
|
||||
template<typename T>
|
||||
static float ComputeCosineDistance(const T *pX, const T *pY, int length) {
|
||||
float diff = 0;
|
||||
const T* pEnd1 = pX + length;
|
||||
while (pX < pEnd1) diff += (*pX++) * (*pY++);
|
||||
return 1 - diff;
|
||||
}
|
||||
*/
|
||||
static float ComputeCosineDistance(const std::int8_t *pX, const std::int8_t *pY, int length) {
|
||||
const std::int8_t* pEnd32 = pX + ((length >> 5) << 5);
|
||||
const std::int8_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::int8_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::int8_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epi8, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epi8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
}
|
||||
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
|
||||
return 16129 - diff;
|
||||
}
|
||||
|
||||
static float ComputeCosineDistance(const std::uint8_t *pX, const std::uint8_t *pY, int length) {
|
||||
const std::uint8_t* pEnd32 = pX + ((length >> 5) << 5);
|
||||
const std::uint8_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::uint8_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::uint8_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd32) {
|
||||
REPEAT(__m256i, __m256i, 32, _mm256_loadu_si256, _mm256_mul_epu8, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 16, _mm_loadu_si128, _mm_mul_epu8, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
}
|
||||
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
|
||||
return 65025 - diff;
|
||||
}
|
||||
|
||||
static float ComputeCosineDistance(const std::int16_t *pX, const std::int16_t *pY, int length) {
|
||||
const std::int16_t* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const std::int16_t* pEnd8 = pX + ((length >> 3) << 3);
|
||||
const std::int16_t* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const std::int16_t* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd8) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd16) {
|
||||
REPEAT(__m256i, __m256i, 16, _mm256_loadu_si256, _mm256_mul_epi16, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd8) {
|
||||
REPEAT(__m128i, __m128i, 8, _mm_loadu_si128, _mm_mul_epi16, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
#endif
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
float c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
c1 = ((float)(*pX++) * (float)(*pY++)); diff += c1;
|
||||
}
|
||||
|
||||
while (pX < pEnd1) diff += ((float)(*pX++) * (float)(*pY++));
|
||||
return 1073676289 - diff;
|
||||
}
|
||||
|
||||
static float ComputeCosineDistance(const float *pX, const float *pY, int length) {
|
||||
const float* pEnd16 = pX + ((length >> 4) << 4);
|
||||
const float* pEnd4 = pX + ((length >> 2) << 2);
|
||||
const float* pEnd1 = pX + length;
|
||||
#if defined(SSE)
|
||||
__m128 diff128 = _mm_setzero_ps();
|
||||
while (pX < pEnd16)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
|
||||
#elif defined(AVX)
|
||||
__m256 diff256 = _mm256_setzero_ps();
|
||||
while (pX < pEnd16)
|
||||
{
|
||||
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256)
|
||||
REPEAT(__m256, const float, 8, _mm256_loadu_ps, _mm256_mul_ps, _mm256_add_ps, diff256)
|
||||
}
|
||||
__m128 diff128 = _mm_add_ps(_mm256_castps256_ps128(diff256), _mm256_extractf128_ps(diff256, 1));
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
REPEAT(__m128, const float, 4, _mm_loadu_ps, _mm_mul_ps, _mm_add_ps, diff128)
|
||||
}
|
||||
float diff = DIFF128[0] + DIFF128[1] + DIFF128[2] + DIFF128[3];
|
||||
#else
|
||||
float diff = 0;
|
||||
while (pX < pEnd4)
|
||||
{
|
||||
float c1 = (*pX++) * (*pY++); diff += c1;
|
||||
c1 = (*pX++) * (*pY++); diff += c1;
|
||||
c1 = (*pX++) * (*pY++); diff += c1;
|
||||
c1 = (*pX++) * (*pY++); diff += c1;
|
||||
}
|
||||
#endif
|
||||
while (pX < pEnd1) diff += (*pX++) * (*pY++);
|
||||
return 1 - diff;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static inline float ComputeDistance(const T *p1, const T *p2, int length, SPTAG::DistCalcMethod distCalcMethod)
|
||||
{
|
||||
if (distCalcMethod == SPTAG::DistCalcMethod::L2)
|
||||
return ComputeL2Distance(p1, p2, length);
|
||||
|
||||
return ComputeCosineDistance(p1, p2, length);
|
||||
}
|
||||
|
||||
static inline float ConvertCosineSimilarityToDistance(float cs)
|
||||
{
|
||||
// Cosine similarity is in [-1, 1], the higher the value, the closer are the two vectors.
|
||||
// However, the tree is built and searched based on "distance" between two vectors, that's >=0. The smaller the value, the closer are the two vectors.
|
||||
// So we do a linear conversion from a cosine similarity to a distance value.
|
||||
return 1 - cs; //[1, 3]
|
||||
}
|
||||
|
||||
static inline float ConvertDistanceBackToCosineSimilarity(float d)
|
||||
{
|
||||
return 1 - d;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template<typename T>
|
||||
float (*DistanceCalcSelector(SPTAG::DistCalcMethod p_method)) (const T*, const T*, int)
|
||||
{
|
||||
switch (p_method)
|
||||
{
|
||||
case SPTAG::DistCalcMethod::Cosine:
|
||||
return &(DistanceUtils::ComputeCosineDistance);
|
||||
|
||||
case SPTAG::DistCalcMethod::L2:
|
||||
return &(DistanceUtils::ComputeL2Distance);
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_DISTANCEUTILS_H_
|
@ -1,51 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_FINEGRAINEDLOCK_H_
|
||||
#define _SPTAG_COMMON_FINEGRAINEDLOCK_H_
|
||||
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
class FineGrainedLock {
|
||||
public:
|
||||
FineGrainedLock() {}
|
||||
~FineGrainedLock() {
|
||||
for (int i = 0; i < locks.size(); i++)
|
||||
locks[i].reset();
|
||||
locks.clear();
|
||||
}
|
||||
|
||||
void resize(int n) {
|
||||
int current = (int)locks.size();
|
||||
if (current <= n) {
|
||||
locks.resize(n);
|
||||
for (int i = current; i < n; i++)
|
||||
locks[i].reset(new std::mutex);
|
||||
}
|
||||
else {
|
||||
for (int i = n; i < current; i++)
|
||||
locks[i].reset();
|
||||
locks.resize(n);
|
||||
}
|
||||
}
|
||||
|
||||
std::mutex& operator[](int idx) {
|
||||
return *locks[idx];
|
||||
}
|
||||
|
||||
const std::mutex& operator[](int idx) const {
|
||||
return *locks[idx];
|
||||
}
|
||||
private:
|
||||
std::vector<std::shared_ptr<std::mutex>> locks;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_FINEGRAINEDLOCK_H_
|
@ -1,105 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_HEAP_H_
|
||||
#define _SPTAG_COMMON_HEAP_H_
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
|
||||
// priority queue
|
||||
template <typename T>
|
||||
class Heap {
|
||||
public:
|
||||
Heap() : heap(nullptr), length(0), count(0) {}
|
||||
|
||||
Heap(int size) { Resize(size); }
|
||||
|
||||
void Resize(int size)
|
||||
{
|
||||
length = size;
|
||||
heap.reset(new T[length + 1]); // heap uses 1-based indexing
|
||||
count = 0;
|
||||
lastlevel = int(pow(2.0, floor(log2(size))));
|
||||
}
|
||||
~Heap() {}
|
||||
inline int size() { return count; }
|
||||
inline bool empty() { return count == 0; }
|
||||
inline void clear() { count = 0; }
|
||||
inline T& Top() { if (count == 0) return heap[0]; else return heap[1]; }
|
||||
|
||||
// Insert a new element in the heap.
|
||||
void insert(T value)
|
||||
{
|
||||
/* If heap is full, then return without adding this element. */
|
||||
int loc;
|
||||
if (count == length) {
|
||||
int maxi = lastlevel;
|
||||
for (int i = lastlevel + 1; i <= length; i++)
|
||||
if (heap[maxi] < heap[i]) maxi = i;
|
||||
if (value > heap[maxi]) return;
|
||||
loc = maxi;
|
||||
}
|
||||
else {
|
||||
loc = ++(count); /* Remember 1-based indexing. */
|
||||
}
|
||||
/* Keep moving parents down until a place is found for this node. */
|
||||
int par = (loc >> 1); /* Location of parent. */
|
||||
while (par > 0 && value < heap[par]) {
|
||||
heap[loc] = heap[par]; /* Move parent down to loc. */
|
||||
loc = par;
|
||||
par >>= 1;
|
||||
}
|
||||
/* Insert the element at the determined location. */
|
||||
heap[loc] = value;
|
||||
}
|
||||
// Returns the node of minimum value from the heap (top of the heap).
|
||||
bool pop(T& value)
|
||||
{
|
||||
if (count == 0) return false;
|
||||
/* Switch first node with last. */
|
||||
value = heap[1];
|
||||
std::swap(heap[1], heap[count]);
|
||||
count--;
|
||||
heapify(); /* Move new node 1 to right position. */
|
||||
return true; /* Return old last node. */
|
||||
}
|
||||
T& pop()
|
||||
{
|
||||
if (count == 0) return heap[0];
|
||||
/* Switch first node with last. */
|
||||
std::swap(heap[1], heap[count]);
|
||||
count--;
|
||||
heapify(); /* Move new node 1 to right position. */
|
||||
return heap[count + 1]; /* Return old last node. */
|
||||
}
|
||||
private:
|
||||
// Storage array for the heap.
|
||||
// Type T must be comparable.
|
||||
std::unique_ptr<T[]> heap;
|
||||
int length;
|
||||
int count; // Number of element in the heap
|
||||
int lastlevel;
|
||||
// Reorganizes the heap (a parent is smaller than its children) starting with a node.
|
||||
|
||||
void heapify()
|
||||
{
|
||||
int parent = 1, next = 2;
|
||||
while (next < count) {
|
||||
if (heap[next] > heap[next + 1]) next++;
|
||||
if (heap[next] < heap[parent]) {
|
||||
std::swap(heap[parent], heap[next]);
|
||||
parent = next;
|
||||
next <<= 1;
|
||||
}
|
||||
else break;
|
||||
}
|
||||
if (next == count && heap[next] < heap[parent]) std::swap(heap[parent], heap[next]);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_HEAP_H_
|
@ -1,358 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_KDTREE_H_
|
||||
#define _SPTAG_COMMON_KDTREE_H_
|
||||
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include "../VectorIndex.h"
|
||||
|
||||
#include "CommonUtils.h"
|
||||
#include "QueryResultSet.h"
|
||||
#include "WorkSpace.h"
|
||||
|
||||
#pragma warning(disable:4996) // 'fopen': This function or variable may be unsafe. Consider using fopen_s instead. To disable deprecation, use _CRT_SECURE_NO_WARNINGS. See online help for details.
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
// node type for storing KDT
|
||||
struct KDTNode
|
||||
{
|
||||
int left;
|
||||
int right;
|
||||
short split_dim;
|
||||
float split_value;
|
||||
};
|
||||
|
||||
class KDTree
|
||||
{
|
||||
public:
|
||||
KDTree() : m_iTreeNumber(2), m_numTopDimensionKDTSplit(5), m_iSamples(1000) {}
|
||||
|
||||
KDTree(KDTree& other) : m_iTreeNumber(other.m_iTreeNumber),
|
||||
m_numTopDimensionKDTSplit(other.m_numTopDimensionKDTSplit),
|
||||
m_iSamples(other.m_iSamples) {}
|
||||
~KDTree() {}
|
||||
|
||||
inline const KDTNode& operator[](int index) const { return m_pTreeRoots[index]; }
|
||||
inline KDTNode& operator[](int index) { return m_pTreeRoots[index]; }
|
||||
|
||||
inline int size() const { return (int)m_pTreeRoots.size(); }
|
||||
|
||||
template <typename T>
|
||||
void BuildTrees(VectorIndex* p_index, std::vector<int>* indices = nullptr)
|
||||
{
|
||||
std::vector<int> localindices;
|
||||
if (indices == nullptr) {
|
||||
localindices.resize(p_index->GetNumSamples());
|
||||
for (int i = 0; i < p_index->GetNumSamples(); i++) localindices[i] = i;
|
||||
}
|
||||
else {
|
||||
localindices.assign(indices->begin(), indices->end());
|
||||
}
|
||||
|
||||
m_pTreeRoots.resize(m_iTreeNumber * localindices.size());
|
||||
m_pTreeStart.resize(m_iTreeNumber, 0);
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < m_iTreeNumber; i++)
|
||||
{
|
||||
Sleep(i * 100); std::srand(clock());
|
||||
|
||||
std::vector<int> pindices(localindices.begin(), localindices.end());
|
||||
std::random_shuffle(pindices.begin(), pindices.end());
|
||||
|
||||
m_pTreeStart[i] = i * (int)pindices.size();
|
||||
std::cout << "Start to build KDTree " << i + 1 << std::endl;
|
||||
int iTreeSize = m_pTreeStart[i];
|
||||
DivideTree<T>(p_index, pindices, 0, (int)pindices.size() - 1, m_pTreeStart[i], iTreeSize);
|
||||
std::cout << i + 1 << " KDTree built, " << iTreeSize - m_pTreeStart[i] << " " << pindices.size() << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
bool SaveTrees(void **pKDTMemFile, int64_t &len) const
|
||||
{
|
||||
int treeNodeSize = (int)m_pTreeRoots.size();
|
||||
|
||||
size_t size = sizeof(int) +
|
||||
sizeof(int) * m_iTreeNumber +
|
||||
sizeof(int) +
|
||||
sizeof(KDTNode) * treeNodeSize;
|
||||
char *mem = (char*)malloc(size);
|
||||
if (mem == NULL) return false;
|
||||
|
||||
auto ptr = mem;
|
||||
*(int*)ptr = m_iTreeNumber;
|
||||
ptr += sizeof(int);
|
||||
|
||||
memcpy(ptr, m_pTreeStart.data(), sizeof(int) * m_iTreeNumber);
|
||||
ptr += sizeof(int) * m_iTreeNumber;
|
||||
|
||||
*(int*)ptr = treeNodeSize;
|
||||
ptr += sizeof(int);
|
||||
|
||||
memcpy(ptr, m_pTreeRoots.data(), sizeof(KDTNode) * treeNodeSize);
|
||||
*pKDTMemFile = mem;
|
||||
len = size;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SaveTrees(std::string sTreeFileName) const
|
||||
{
|
||||
std::cout << "Save KDT to " << sTreeFileName << std::endl;
|
||||
FILE *fp = fopen(sTreeFileName.c_str(), "wb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fwrite(&m_iTreeNumber, sizeof(int), 1, fp);
|
||||
fwrite(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
|
||||
int treeNodeSize = (int)m_pTreeRoots.size();
|
||||
fwrite(&treeNodeSize, sizeof(int), 1, fp);
|
||||
fwrite(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
|
||||
fclose(fp);
|
||||
std::cout << "Save KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadTrees(char* pKDTMemFile)
|
||||
{
|
||||
m_iTreeNumber = *((int*)pKDTMemFile);
|
||||
pKDTMemFile += sizeof(int);
|
||||
m_pTreeStart.resize(m_iTreeNumber);
|
||||
memcpy(m_pTreeStart.data(), pKDTMemFile, sizeof(int) * m_iTreeNumber);
|
||||
pKDTMemFile += sizeof(int)*m_iTreeNumber;
|
||||
|
||||
int treeNodeSize = *((int*)pKDTMemFile);
|
||||
pKDTMemFile += sizeof(int);
|
||||
m_pTreeRoots.resize(treeNodeSize);
|
||||
memcpy(m_pTreeRoots.data(), pKDTMemFile, sizeof(KDTNode) * treeNodeSize);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadTrees(std::string sTreeFileName)
|
||||
{
|
||||
std::cout << "Load KDT From " << sTreeFileName << std::endl;
|
||||
FILE *fp = fopen(sTreeFileName.c_str(), "rb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fread(&m_iTreeNumber, sizeof(int), 1, fp);
|
||||
m_pTreeStart.resize(m_iTreeNumber);
|
||||
fread(m_pTreeStart.data(), sizeof(int), m_iTreeNumber, fp);
|
||||
|
||||
int treeNodeSize;
|
||||
fread(&treeNodeSize, sizeof(int), 1, fp);
|
||||
m_pTreeRoots.resize(treeNodeSize);
|
||||
fread(m_pTreeRoots.data(), sizeof(KDTNode), treeNodeSize, fp);
|
||||
fclose(fp);
|
||||
std::cout << "Load KDT (" << m_iTreeNumber << "," << treeNodeSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InitSearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
|
||||
{
|
||||
for (char i = 0; i < m_iTreeNumber; i++) {
|
||||
KDTSearch(p_index, p_query, p_space, m_pTreeStart[i], true, 0);
|
||||
}
|
||||
|
||||
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
|
||||
{
|
||||
auto& tcell = p_space.m_SPTQueue.pop();
|
||||
if (p_query.worstDist() < tcell.distance) break;
|
||||
KDTSearch(p_index, p_query, p_space, tcell.node, true, tcell.distance);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SearchTrees(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const int p_limits) const
|
||||
{
|
||||
while (!p_space.m_SPTQueue.empty() && p_space.m_iNumberOfCheckedLeaves < p_limits)
|
||||
{
|
||||
auto& tcell = p_space.m_SPTQueue.pop();
|
||||
KDTSearch(p_index, p_query, p_space, tcell.node, false, tcell.distance);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
template <typename T>
|
||||
void KDTSearch(const VectorIndex* p_index, const COMMON::QueryResultSet<T> &p_query,
|
||||
COMMON::WorkSpace& p_space, const int node, const bool isInit, const float distBound) const {
|
||||
if (node < 0)
|
||||
{
|
||||
int index = -node - 1;
|
||||
if (index >= p_index->GetNumSamples()) return;
|
||||
#ifdef PREFETCH
|
||||
const char* data = (const char *)(p_index->GetSample(index));
|
||||
_mm_prefetch(data, _MM_HINT_T0);
|
||||
_mm_prefetch(data + 64, _MM_HINT_T0);
|
||||
#endif
|
||||
if (p_space.CheckAndSet(index)) return;
|
||||
|
||||
++p_space.m_iNumberOfTreeCheckedLeaves;
|
||||
++p_space.m_iNumberOfCheckedLeaves;
|
||||
p_space.m_NGQueue.insert(COMMON::HeapCell(index, p_index->ComputeDistance((const void*)p_query.GetTarget(), (const void*)data)));
|
||||
return;
|
||||
}
|
||||
|
||||
auto& tnode = m_pTreeRoots[node];
|
||||
|
||||
float diff = (p_query.GetTarget())[tnode.split_dim] - tnode.split_value;
|
||||
float distanceBound = distBound + diff * diff;
|
||||
int otherChild, bestChild;
|
||||
if (diff < 0)
|
||||
{
|
||||
bestChild = tnode.left;
|
||||
otherChild = tnode.right;
|
||||
}
|
||||
else
|
||||
{
|
||||
otherChild = tnode.left;
|
||||
bestChild = tnode.right;
|
||||
}
|
||||
|
||||
if (!isInit || distanceBound < p_query.worstDist())
|
||||
{
|
||||
p_space.m_SPTQueue.insert(COMMON::HeapCell(otherChild, distanceBound));
|
||||
}
|
||||
KDTSearch(p_index, p_query, p_space, bestChild, isInit, distBound);
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void DivideTree(VectorIndex* p_index, std::vector<int>& indices, int first, int last,
|
||||
int index, int &iTreeSize) {
|
||||
ChooseDivision<T>(p_index, m_pTreeRoots[index], indices, first, last);
|
||||
int i = Subdivide<T>(p_index, m_pTreeRoots[index], indices, first, last);
|
||||
if (i - 1 <= first)
|
||||
{
|
||||
m_pTreeRoots[index].left = -indices[first] - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
iTreeSize++;
|
||||
m_pTreeRoots[index].left = iTreeSize;
|
||||
DivideTree<T>(p_index, indices, first, i - 1, iTreeSize, iTreeSize);
|
||||
}
|
||||
if (last == i)
|
||||
{
|
||||
m_pTreeRoots[index].right = -indices[last] - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
iTreeSize++;
|
||||
m_pTreeRoots[index].right = iTreeSize;
|
||||
DivideTree<T>(p_index, indices, i, last, iTreeSize, iTreeSize);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ChooseDivision(VectorIndex* p_index, KDTNode& node, const std::vector<int>& indices, const int first, const int last)
|
||||
{
|
||||
std::vector<float> meanValues(p_index->GetFeatureDim(), 0);
|
||||
std::vector<float> varianceValues(p_index->GetFeatureDim(), 0);
|
||||
int end = min(first + m_iSamples, last);
|
||||
int count = end - first + 1;
|
||||
// calculate the mean of each dimension
|
||||
for (int j = first; j <= end; j++)
|
||||
{
|
||||
const T* v = (const T*)p_index->GetSample(indices[j]);
|
||||
for (int k = 0; k < p_index->GetFeatureDim(); k++)
|
||||
{
|
||||
meanValues[k] += v[k];
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < p_index->GetFeatureDim(); k++)
|
||||
{
|
||||
meanValues[k] /= count;
|
||||
}
|
||||
// calculate the variance of each dimension
|
||||
for (int j = first; j <= end; j++)
|
||||
{
|
||||
const T* v = (const T*)p_index->GetSample(indices[j]);
|
||||
for (int k = 0; k < p_index->GetFeatureDim(); k++)
|
||||
{
|
||||
float dist = v[k] - meanValues[k];
|
||||
varianceValues[k] += dist*dist;
|
||||
}
|
||||
}
|
||||
// choose the split dimension as one of the dimension inside TOP_DIM maximum variance
|
||||
node.split_dim = SelectDivisionDimension(varianceValues);
|
||||
// determine the threshold
|
||||
node.split_value = meanValues[node.split_dim];
|
||||
}
|
||||
|
||||
int SelectDivisionDimension(const std::vector<float>& varianceValues) const
|
||||
{
|
||||
// Record the top maximum variances
|
||||
std::vector<int> topind(m_numTopDimensionKDTSplit);
|
||||
int num = 0;
|
||||
// order the variances
|
||||
for (int i = 0; i < varianceValues.size(); i++)
|
||||
{
|
||||
if (num < m_numTopDimensionKDTSplit || varianceValues[i] > varianceValues[topind[num - 1]])
|
||||
{
|
||||
if (num < m_numTopDimensionKDTSplit)
|
||||
{
|
||||
topind[num++] = i;
|
||||
}
|
||||
else
|
||||
{
|
||||
topind[num - 1] = i;
|
||||
}
|
||||
int j = num - 1;
|
||||
// order the TOP_DIM variances
|
||||
while (j > 0 && varianceValues[topind[j]] > varianceValues[topind[j - 1]])
|
||||
{
|
||||
std::swap(topind[j], topind[j - 1]);
|
||||
j--;
|
||||
}
|
||||
}
|
||||
}
|
||||
// randomly choose a dimension from TOP_DIM
|
||||
return topind[COMMON::Utils::rand_int(num)];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int Subdivide(VectorIndex* p_index, const KDTNode& node, std::vector<int>& indices, const int first, const int last) const
|
||||
{
|
||||
int i = first;
|
||||
int j = last;
|
||||
// decide which child one point belongs
|
||||
while (i <= j)
|
||||
{
|
||||
int ind = indices[i];
|
||||
const T* v = (const T*)p_index->GetSample(ind);
|
||||
float val = v[node.split_dim];
|
||||
if (val < node.split_value)
|
||||
{
|
||||
i++;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::swap(indices[i], indices[j]);
|
||||
j--;
|
||||
}
|
||||
}
|
||||
// if all the points in the node are equal,equally split the node into 2
|
||||
if ((i == first) || (i == last + 1))
|
||||
{
|
||||
i = (first + last + 1) / 2;
|
||||
}
|
||||
return i;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int> m_pTreeStart;
|
||||
std::vector<KDTNode> m_pTreeRoots;
|
||||
|
||||
public:
|
||||
int m_iTreeNumber, m_numTopDimensionKDTSplit, m_iSamples;
|
||||
};
|
||||
}
|
||||
}
|
||||
#endif
|
@ -1,436 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_NG_H_
|
||||
#define _SPTAG_COMMON_NG_H_
|
||||
|
||||
#include "../VectorIndex.h"
|
||||
|
||||
#include "CommonUtils.h"
|
||||
#include "Dataset.h"
|
||||
#include "FineGrainedLock.h"
|
||||
#include "QueryResultSet.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
class NeighborhoodGraph
|
||||
{
|
||||
public:
|
||||
NeighborhoodGraph(): m_iTPTNumber(32),
|
||||
m_iTPTLeafSize(2000),
|
||||
m_iSamples(1000),
|
||||
m_numTopDimensionTPTSplit(5),
|
||||
m_iNeighborhoodSize(32),
|
||||
m_iNeighborhoodScale(2),
|
||||
m_iCEFScale(2),
|
||||
m_iRefineIter(0),
|
||||
m_iCEF(1000),
|
||||
m_iMaxCheckForRefineGraph(10000) {}
|
||||
|
||||
~NeighborhoodGraph() {}
|
||||
|
||||
virtual void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist) = 0;
|
||||
|
||||
virtual void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) = 0;
|
||||
|
||||
virtual float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr) = 0;
|
||||
|
||||
template <typename T>
|
||||
void BuildGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
|
||||
{
|
||||
std::cout << "build RNG graph!" << std::endl;
|
||||
|
||||
m_iGraphSize = index->GetNumSamples();
|
||||
m_iNeighborhoodSize = m_iNeighborhoodSize * m_iNeighborhoodScale;
|
||||
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
|
||||
m_dataUpdateLock.resize(m_iGraphSize);
|
||||
|
||||
if (m_iGraphSize < 1000) {
|
||||
RefineGraph<T>(index, idmap);
|
||||
std::cout << "Build RNG Graph end!" << std::endl;
|
||||
return;
|
||||
}
|
||||
|
||||
{
|
||||
COMMON::Dataset<float> NeighborhoodDists(m_iGraphSize, m_iNeighborhoodSize);
|
||||
std::vector<std::vector<int>> TptreeDataIndices(m_iTPTNumber, std::vector<int>(m_iGraphSize));
|
||||
std::vector<std::vector<std::pair<int, int>>> TptreeLeafNodes(m_iTPTNumber, std::vector<std::pair<int, int>>());
|
||||
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
for (int j = 0; j < m_iNeighborhoodSize; j++)
|
||||
(NeighborhoodDists)[i][j] = MaxDist;
|
||||
|
||||
std::cout << "Parallel TpTree Partition begin " << std::endl;
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int i = 0; i < m_iTPTNumber; i++)
|
||||
{
|
||||
Sleep(i * 100); std::srand(clock());
|
||||
for (int j = 0; j < m_iGraphSize; j++) TptreeDataIndices[i][j] = j;
|
||||
std::random_shuffle(TptreeDataIndices[i].begin(), TptreeDataIndices[i].end());
|
||||
PartitionByTptree<T>(index, TptreeDataIndices[i], 0, m_iGraphSize - 1, TptreeLeafNodes[i]);
|
||||
std::cout << "Finish Getting Leaves for Tree " << i << std::endl;
|
||||
}
|
||||
std::cout << "Parallel TpTree Partition done" << std::endl;
|
||||
|
||||
for (int i = 0; i < m_iTPTNumber; i++)
|
||||
{
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int j = 0; j < TptreeLeafNodes[i].size(); j++)
|
||||
{
|
||||
int start_index = TptreeLeafNodes[i][j].first;
|
||||
int end_index = TptreeLeafNodes[i][j].second;
|
||||
if (omp_get_thread_num() == 0) std::cout << "\rProcessing Tree " << i << ' ' << j * 100 / TptreeLeafNodes[i].size() << '%';
|
||||
for (int x = start_index; x < end_index; x++)
|
||||
{
|
||||
for (int y = x + 1; y <= end_index; y++)
|
||||
{
|
||||
int p1 = TptreeDataIndices[i][x];
|
||||
int p2 = TptreeDataIndices[i][y];
|
||||
float dist = index->ComputeDistance(index->GetSample(p1), index->GetSample(p2));
|
||||
if (idmap != nullptr) {
|
||||
p1 = (idmap->find(p1) == idmap->end()) ? p1 : idmap->at(p1);
|
||||
p2 = (idmap->find(p2) == idmap->end()) ? p2 : idmap->at(p2);
|
||||
}
|
||||
COMMON::Utils::AddNeighbor(p2, dist, (m_pNeighborhoodGraph)[p1], (NeighborhoodDists)[p1], m_iNeighborhoodSize);
|
||||
COMMON::Utils::AddNeighbor(p1, dist, (m_pNeighborhoodGraph)[p2], (NeighborhoodDists)[p2], m_iNeighborhoodSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
TptreeDataIndices[i].clear();
|
||||
TptreeLeafNodes[i].clear();
|
||||
std::cout << std::endl;
|
||||
}
|
||||
TptreeDataIndices.clear();
|
||||
TptreeLeafNodes.clear();
|
||||
}
|
||||
|
||||
if (m_iMaxCheckForRefineGraph > 0) {
|
||||
RefineGraph<T>(index, idmap);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RefineGraph(VectorIndex* index, const std::unordered_map<int, int>* idmap = nullptr)
|
||||
{
|
||||
m_iCEF *= m_iCEFScale;
|
||||
m_iMaxCheckForRefineGraph *= m_iCEFScale;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
{
|
||||
RefineNode<T>(index, i, false);
|
||||
if (i % 1000 == 0) std::cout << "\rRefine 1 " << (i * 100 / m_iGraphSize) << "%";
|
||||
}
|
||||
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
|
||||
|
||||
m_iCEF /= m_iCEFScale;
|
||||
m_iMaxCheckForRefineGraph /= m_iCEFScale;
|
||||
m_iNeighborhoodSize /= m_iNeighborhoodScale;
|
||||
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
{
|
||||
RefineNode<T>(index, i, false);
|
||||
if (i % 1000 == 0) std::cout << "\rRefine 2 " << (i * 100 / m_iGraphSize) << "%";
|
||||
}
|
||||
std::cout << "Refine RNG, graph acc:" << GraphAccuracyEstimation(index, 100, idmap) << std::endl;
|
||||
|
||||
if (idmap != nullptr) {
|
||||
for (auto iter = idmap->begin(); iter != idmap->end(); iter++)
|
||||
if (iter->first < 0)
|
||||
{
|
||||
m_pNeighborhoodGraph[-1 - iter->first][m_iNeighborhoodSize - 1] = -2 - iter->second;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ErrorCode RefineGraph(VectorIndex* index, std::vector<int>& indices, std::vector<int>& reverseIndices,
|
||||
std::string graphFileName, const std::unordered_map<int, int>* idmap = nullptr)
|
||||
{
|
||||
int R = (int)indices.size();
|
||||
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int i = 0; i < R; i++)
|
||||
{
|
||||
RefineNode<T>(index, indices[i], false);
|
||||
int* nodes = m_pNeighborhoodGraph[indices[i]];
|
||||
for (int j = 0; j < m_iNeighborhoodSize; j++)
|
||||
{
|
||||
if (nodes[j] < 0) nodes[j] = -1;
|
||||
else nodes[j] = reverseIndices[nodes[j]];
|
||||
}
|
||||
if (idmap == nullptr || idmap->find(-1 - indices[i]) == idmap->end()) continue;
|
||||
nodes[m_iNeighborhoodSize - 1] = -2 - idmap->at(-1 - indices[i]);
|
||||
}
|
||||
|
||||
std::ofstream graphOut(graphFileName, std::ios::binary);
|
||||
if (!graphOut.is_open()) return ErrorCode::FailedCreateFile;
|
||||
graphOut.write((char*)&R, sizeof(int));
|
||||
graphOut.write((char*)&m_iNeighborhoodSize, sizeof(int));
|
||||
for (int i = 0; i < R; i++) {
|
||||
graphOut.write((char*)m_pNeighborhoodGraph[indices[i]], sizeof(int) * m_iNeighborhoodSize);
|
||||
}
|
||||
graphOut.close();
|
||||
return ErrorCode::Success;
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void RefineNode(VectorIndex* index, const int node, bool updateNeighbors)
|
||||
{
|
||||
COMMON::QueryResultSet<T> query((const T*)index->GetSample(node), m_iCEF + 1);
|
||||
index->SearchIndex(query);
|
||||
RebuildNeighbors(index, node, m_pNeighborhoodGraph[node], query.GetResults(), m_iCEF + 1);
|
||||
|
||||
if (updateNeighbors) {
|
||||
// update neighbors
|
||||
for (int j = 0; j <= m_iCEF; j++)
|
||||
{
|
||||
BasicResult* item = query.GetResult(j);
|
||||
if (item->VID < 0) break;
|
||||
if (item->VID == node) continue;
|
||||
|
||||
std::lock_guard<std::mutex> lock(m_dataUpdateLock[item->VID]);
|
||||
InsertNeighbors(index, item->VID, node, item->Dist);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PartitionByTptree(VectorIndex* index, std::vector<int>& indices, const int first, const int last,
|
||||
std::vector<std::pair<int, int>> & leaves)
|
||||
{
|
||||
if (last - first <= m_iTPTLeafSize)
|
||||
{
|
||||
leaves.push_back(std::make_pair(first, last));
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<float> Mean(index->GetFeatureDim(), 0);
|
||||
|
||||
int iIteration = 100;
|
||||
int end = min(first + m_iSamples, last);
|
||||
int count = end - first + 1;
|
||||
// calculate the mean of each dimension
|
||||
for (int j = first; j <= end; j++)
|
||||
{
|
||||
const T* v = (const T*)index->GetSample(indices[j]);
|
||||
for (int k = 0; k < index->GetFeatureDim(); k++)
|
||||
{
|
||||
Mean[k] += v[k];
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < index->GetFeatureDim(); k++)
|
||||
{
|
||||
Mean[k] /= count;
|
||||
}
|
||||
std::vector<BasicResult> Variance;
|
||||
Variance.reserve(index->GetFeatureDim());
|
||||
for (int j = 0; j < index->GetFeatureDim(); j++)
|
||||
{
|
||||
Variance.push_back(BasicResult(j, 0));
|
||||
}
|
||||
// calculate the variance of each dimension
|
||||
for (int j = first; j <= end; j++)
|
||||
{
|
||||
const T* v = (const T*)index->GetSample(indices[j]);
|
||||
for (int k = 0; k < index->GetFeatureDim(); k++)
|
||||
{
|
||||
float dist = v[k] - Mean[k];
|
||||
Variance[k].Dist += dist*dist;
|
||||
}
|
||||
}
|
||||
std::sort(Variance.begin(), Variance.end(), COMMON::Compare);
|
||||
std::vector<int> indexs(m_numTopDimensionTPTSplit);
|
||||
std::vector<float> weight(m_numTopDimensionTPTSplit), bestweight(m_numTopDimensionTPTSplit);
|
||||
float bestvariance = Variance[index->GetFeatureDim() - 1].Dist;
|
||||
for (int i = 0; i < m_numTopDimensionTPTSplit; i++)
|
||||
{
|
||||
indexs[i] = Variance[index->GetFeatureDim() - 1 - i].VID;
|
||||
bestweight[i] = 0;
|
||||
}
|
||||
bestweight[0] = 1;
|
||||
float bestmean = Mean[indexs[0]];
|
||||
|
||||
std::vector<float> Val(count);
|
||||
for (int i = 0; i < iIteration; i++)
|
||||
{
|
||||
float sumweight = 0;
|
||||
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
|
||||
{
|
||||
weight[j] = float(rand() % 10000) / 5000.0f - 1.0f;
|
||||
sumweight += weight[j] * weight[j];
|
||||
}
|
||||
sumweight = sqrt(sumweight);
|
||||
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
|
||||
{
|
||||
weight[j] /= sumweight;
|
||||
}
|
||||
float mean = 0;
|
||||
for (int j = 0; j < count; j++)
|
||||
{
|
||||
Val[j] = 0;
|
||||
const T* v = (const T*)index->GetSample(indices[first + j]);
|
||||
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
|
||||
{
|
||||
Val[j] += weight[k] * v[indexs[k]];
|
||||
}
|
||||
mean += Val[j];
|
||||
}
|
||||
mean /= count;
|
||||
float var = 0;
|
||||
for (int j = 0; j < count; j++)
|
||||
{
|
||||
float dist = Val[j] - mean;
|
||||
var += dist * dist;
|
||||
}
|
||||
if (var > bestvariance)
|
||||
{
|
||||
bestvariance = var;
|
||||
bestmean = mean;
|
||||
for (int j = 0; j < m_numTopDimensionTPTSplit; j++)
|
||||
{
|
||||
bestweight[j] = weight[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
int i = first;
|
||||
int j = last;
|
||||
// decide which child one point belongs
|
||||
while (i <= j)
|
||||
{
|
||||
float val = 0;
|
||||
const T* v = (const T*)index->GetSample(indices[i]);
|
||||
for (int k = 0; k < m_numTopDimensionTPTSplit; k++)
|
||||
{
|
||||
val += bestweight[k] * v[indexs[k]];
|
||||
}
|
||||
if (val < bestmean)
|
||||
{
|
||||
i++;
|
||||
}
|
||||
else
|
||||
{
|
||||
std::swap(indices[i], indices[j]);
|
||||
j--;
|
||||
}
|
||||
}
|
||||
// if all the points in the node are equal,equally split the node into 2
|
||||
if ((i == first) || (i == last + 1))
|
||||
{
|
||||
i = (first + last + 1) / 2;
|
||||
}
|
||||
|
||||
Mean.clear();
|
||||
Variance.clear();
|
||||
Val.clear();
|
||||
indexs.clear();
|
||||
weight.clear();
|
||||
bestweight.clear();
|
||||
|
||||
PartitionByTptree<T>(index, indices, first, i - 1, leaves);
|
||||
PartitionByTptree<T>(index, indices, i, last, leaves);
|
||||
}
|
||||
}
|
||||
|
||||
bool LoadGraph(std::string sGraphFilename)
|
||||
{
|
||||
std::cout << "Load Graph From " << sGraphFilename << std::endl;
|
||||
FILE * fp = fopen(sGraphFilename.c_str(), "rb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fread(&m_iGraphSize, sizeof(int), 1, fp);
|
||||
fread(&m_iNeighborhoodSize, sizeof(int), 1, fp);
|
||||
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize);
|
||||
m_dataUpdateLock.resize(m_iGraphSize);
|
||||
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
{
|
||||
fread((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
|
||||
}
|
||||
fclose(fp);
|
||||
std::cout << "Load Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool LoadGraphFromMemory(char* pGraphMemFile)
|
||||
{
|
||||
m_iGraphSize = *((int*)pGraphMemFile);
|
||||
pGraphMemFile += sizeof(int);
|
||||
|
||||
m_iNeighborhoodSize = *((int*)pGraphMemFile);
|
||||
pGraphMemFile += sizeof(int);
|
||||
|
||||
m_pNeighborhoodGraph.Initialize(m_iGraphSize, m_iNeighborhoodSize, (int*)pGraphMemFile);
|
||||
m_dataUpdateLock.resize(m_iGraphSize);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SaveGraph(std::string sGraphFilename) const
|
||||
{
|
||||
std::cout << "Save Graph To " << sGraphFilename << std::endl;
|
||||
FILE *fp = fopen(sGraphFilename.c_str(), "wb");
|
||||
if (fp == NULL) return false;
|
||||
|
||||
fwrite(&m_iGraphSize, sizeof(int), 1, fp);
|
||||
fwrite(&m_iNeighborhoodSize, sizeof(int), 1, fp);
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
{
|
||||
fwrite((m_pNeighborhoodGraph)[i], sizeof(int), m_iNeighborhoodSize, fp);
|
||||
}
|
||||
fclose(fp);
|
||||
std::cout << "Save Graph (" << m_iGraphSize << "," << m_iNeighborhoodSize << ") Finish!" << std::endl;
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SaveGraphToMemory(void **pGraphMemFile, int64_t &len) {
|
||||
size_t size = sizeof(int) + sizeof(int) + sizeof(int) * m_iNeighborhoodSize * m_iGraphSize;
|
||||
char *mem = (char*)malloc(size);
|
||||
if (mem == NULL) return false;
|
||||
|
||||
auto ptr = mem;
|
||||
*(int*)ptr = m_iGraphSize;
|
||||
ptr += sizeof(int);
|
||||
|
||||
*(int*)ptr = m_iNeighborhoodSize;
|
||||
ptr += sizeof(int);
|
||||
|
||||
for (int i = 0; i < m_iGraphSize; i++)
|
||||
{
|
||||
memcpy(ptr, (m_pNeighborhoodGraph)[i], sizeof(int) * m_iNeighborhoodSize);
|
||||
ptr += sizeof(int) * m_iNeighborhoodSize;
|
||||
}
|
||||
*pGraphMemFile = mem;
|
||||
len = size;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
inline void AddBatch(int num) { m_pNeighborhoodGraph.AddBatch(num); m_iGraphSize += num; m_dataUpdateLock.resize(m_iGraphSize); }
|
||||
|
||||
inline int* operator[](int index) { return m_pNeighborhoodGraph[index]; }
|
||||
|
||||
inline const int* operator[](int index) const { return m_pNeighborhoodGraph[index]; }
|
||||
|
||||
inline void SetR(int rows) { m_pNeighborhoodGraph.SetR(rows); m_iGraphSize = rows; m_dataUpdateLock.resize(m_iGraphSize); }
|
||||
|
||||
inline int R() const { return m_iGraphSize; }
|
||||
|
||||
static std::shared_ptr<NeighborhoodGraph> CreateInstance(std::string type);
|
||||
|
||||
protected:
|
||||
// Graph structure
|
||||
int m_iGraphSize;
|
||||
COMMON::Dataset<int> m_pNeighborhoodGraph;
|
||||
COMMON::FineGrainedLock m_dataUpdateLock; // protect one row of the graph
|
||||
|
||||
public:
|
||||
int m_iTPTNumber, m_iTPTLeafSize, m_iSamples, m_numTopDimensionTPTSplit;
|
||||
int m_iNeighborhoodSize, m_iNeighborhoodScale, m_iCEFScale, m_iRefineIter, m_iCEF, m_iMaxCheckForRefineGraph;
|
||||
};
|
||||
}
|
||||
}
|
||||
#endif
|
@ -1,96 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_QUERYRESULTSET_H_
|
||||
#define _SPTAG_COMMON_QUERYRESULTSET_H_
|
||||
|
||||
#include "../SearchQuery.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
|
||||
inline bool operator < (const BasicResult& lhs, const BasicResult& rhs)
|
||||
{
|
||||
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
|
||||
}
|
||||
|
||||
|
||||
inline bool Compare(const BasicResult& lhs, const BasicResult& rhs)
|
||||
{
|
||||
return ((lhs.Dist < rhs.Dist) || ((lhs.Dist == rhs.Dist) && (lhs.VID < rhs.VID)));
|
||||
}
|
||||
|
||||
|
||||
// Space to save temporary answer, similar with TopKCache
|
||||
template<typename T>
|
||||
class QueryResultSet : public QueryResult
|
||||
{
|
||||
public:
|
||||
QueryResultSet(const T* _target, int _K) : QueryResult(_target, _K, false)
|
||||
{
|
||||
}
|
||||
|
||||
QueryResultSet(const QueryResultSet& other) : QueryResult(other)
|
||||
{
|
||||
}
|
||||
|
||||
inline void SetTarget(const T *p_target)
|
||||
{
|
||||
m_target = p_target;
|
||||
}
|
||||
|
||||
inline const T* GetTarget() const
|
||||
{
|
||||
return reinterpret_cast<const T*>(m_target);
|
||||
}
|
||||
|
||||
inline float worstDist() const
|
||||
{
|
||||
return m_results[0].Dist;
|
||||
}
|
||||
|
||||
bool AddPoint(const int index, float dist)
|
||||
{
|
||||
if (dist < m_results[0].Dist || (dist == m_results[0].Dist && index < m_results[0].VID))
|
||||
{
|
||||
m_results[0].VID = index;
|
||||
m_results[0].Dist = dist;
|
||||
Heapify(m_resultNum);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
inline void SortResult()
|
||||
{
|
||||
for (int i = m_resultNum - 1; i >= 0; i--)
|
||||
{
|
||||
std::swap(m_results[0], m_results[i]);
|
||||
Heapify(i);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void Heapify(int count)
|
||||
{
|
||||
int parent = 0, next = 1, maxidx = count - 1;
|
||||
while (next < maxidx)
|
||||
{
|
||||
if (m_results[next] < m_results[next + 1]) next++;
|
||||
if (m_results[parent] < m_results[next])
|
||||
{
|
||||
std::swap(m_results[next], m_results[parent]);
|
||||
parent = next;
|
||||
next = (parent << 1) + 1;
|
||||
}
|
||||
else break;
|
||||
}
|
||||
if (next == maxidx && m_results[parent] < m_results[next]) std::swap(m_results[parent], m_results[next]);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_QUERYRESULTSET_H_
|
@ -1,123 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_RNG_H_
|
||||
#define _SPTAG_COMMON_RNG_H_
|
||||
|
||||
#include "NeighborhoodGraph.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
class RelativeNeighborhoodGraph: public NeighborhoodGraph
|
||||
{
|
||||
public:
|
||||
void RebuildNeighbors(VectorIndex* index, const int node, int* nodes, const BasicResult* queryResults, const int numResults) {
|
||||
int count = 0;
|
||||
for (int j = 0; j < numResults && count < m_iNeighborhoodSize; j++) {
|
||||
const BasicResult& item = queryResults[j];
|
||||
if (item.VID < 0) break;
|
||||
if (item.VID == node) continue;
|
||||
|
||||
bool good = true;
|
||||
for (int k = 0; k < count; k++) {
|
||||
if (index->ComputeDistance(index->GetSample(nodes[k]), index->GetSample(item.VID)) <= item.Dist) {
|
||||
good = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (good) nodes[count++] = item.VID;
|
||||
}
|
||||
for (int j = count; j < m_iNeighborhoodSize; j++) nodes[j] = -1;
|
||||
}
|
||||
|
||||
void InsertNeighbors(VectorIndex* index, const int node, int insertNode, float insertDist)
|
||||
{
|
||||
int* nodes = m_pNeighborhoodGraph[node];
|
||||
for (int k = 0; k < m_iNeighborhoodSize; k++)
|
||||
{
|
||||
int tmpNode = nodes[k];
|
||||
if (tmpNode < -1) continue;
|
||||
|
||||
if (tmpNode < 0)
|
||||
{
|
||||
bool good = true;
|
||||
for (int t = 0; t < k; t++) {
|
||||
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
|
||||
good = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (good) {
|
||||
nodes[k] = insertNode;
|
||||
}
|
||||
break;
|
||||
}
|
||||
float tmpDist = index->ComputeDistance(index->GetSample(node), index->GetSample(tmpNode));
|
||||
if (insertDist < tmpDist || (insertDist == tmpDist && insertNode < tmpNode))
|
||||
{
|
||||
bool good = true;
|
||||
for (int t = 0; t < k; t++) {
|
||||
if (index->ComputeDistance(index->GetSample(insertNode), index->GetSample(nodes[t])) < insertDist) {
|
||||
good = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (good) {
|
||||
nodes[k] = insertNode;
|
||||
insertNode = tmpNode;
|
||||
insertDist = tmpDist;
|
||||
}
|
||||
else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float GraphAccuracyEstimation(VectorIndex* index, const int samples, const std::unordered_map<int, int>* idmap = nullptr)
|
||||
{
|
||||
int* correct = new int[samples];
|
||||
|
||||
#pragma omp parallel for schedule(dynamic)
|
||||
for (int i = 0; i < samples; i++)
|
||||
{
|
||||
int x = COMMON::Utils::rand_int(m_iGraphSize);
|
||||
//int x = i;
|
||||
COMMON::QueryResultSet<void> query(nullptr, m_iCEF);
|
||||
for (int y = 0; y < m_iGraphSize; y++)
|
||||
{
|
||||
if ((idmap != nullptr && idmap->find(y) != idmap->end())) continue;
|
||||
float dist = index->ComputeDistance(index->GetSample(x), index->GetSample(y));
|
||||
query.AddPoint(y, dist);
|
||||
}
|
||||
query.SortResult();
|
||||
int * exact_rng = new int[m_iNeighborhoodSize];
|
||||
RebuildNeighbors(index, x, exact_rng, query.GetResults(), m_iCEF);
|
||||
|
||||
correct[i] = 0;
|
||||
for (int j = 0; j < m_iNeighborhoodSize; j++) {
|
||||
if (exact_rng[j] == -1) {
|
||||
correct[i] += m_iNeighborhoodSize - j;
|
||||
break;
|
||||
}
|
||||
for (int k = 0; k < m_iNeighborhoodSize; k++)
|
||||
if ((m_pNeighborhoodGraph)[x][k] == exact_rng[j]) {
|
||||
correct[i]++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
delete[] exact_rng;
|
||||
}
|
||||
float acc = 0;
|
||||
for (int i = 0; i < samples; i++) acc += float(correct[i]);
|
||||
acc = acc / samples / m_iNeighborhoodSize;
|
||||
delete[] correct;
|
||||
return acc;
|
||||
}
|
||||
|
||||
};
|
||||
}
|
||||
}
|
||||
#endif
|
@ -1,185 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_WORKSPACE_H_
|
||||
#define _SPTAG_COMMON_WORKSPACE_H_
|
||||
|
||||
#include "CommonUtils.h"
|
||||
#include "Heap.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
// node type in the priority queue
|
||||
struct HeapCell
|
||||
{
|
||||
int node;
|
||||
float distance;
|
||||
|
||||
HeapCell(int _node = -1, float _distance = MaxDist) : node(_node), distance(_distance) {}
|
||||
|
||||
inline bool operator < (const HeapCell& rhs)
|
||||
{
|
||||
return distance < rhs.distance;
|
||||
}
|
||||
|
||||
inline bool operator > (const HeapCell& rhs)
|
||||
{
|
||||
return distance > rhs.distance;
|
||||
}
|
||||
};
|
||||
|
||||
class OptHashPosVector
|
||||
{
|
||||
protected:
|
||||
// Max loop number in one hash block.
|
||||
static const int m_maxLoop = 8;
|
||||
|
||||
// Max pool size.
|
||||
static const int m_poolSize = 8191;
|
||||
|
||||
// Could we use the second hash block.
|
||||
bool m_secondHash;
|
||||
|
||||
// Record 2 hash tables.
|
||||
// [0~m_poolSize + 1) is the first block.
|
||||
// [m_poolSize + 1, 2*(m_poolSize + 1)) is the second block;
|
||||
int m_hashTable[(m_poolSize + 1) * 2];
|
||||
|
||||
|
||||
inline unsigned hash_func2(int idx, int loop)
|
||||
{
|
||||
return ((unsigned)idx + loop) & m_poolSize;
|
||||
}
|
||||
|
||||
|
||||
inline unsigned hash_func(unsigned idx)
|
||||
{
|
||||
return ((unsigned)(idx * 99991) + _rotl(idx, 2) + 101) & m_poolSize;
|
||||
}
|
||||
|
||||
public:
|
||||
OptHashPosVector() {}
|
||||
|
||||
~OptHashPosVector() {}
|
||||
|
||||
|
||||
void Init(int size)
|
||||
{
|
||||
m_secondHash = true;
|
||||
clear();
|
||||
}
|
||||
|
||||
void clear()
|
||||
{
|
||||
if (!m_secondHash)
|
||||
{
|
||||
// Clear first block.
|
||||
memset(&m_hashTable[0], 0, sizeof(int)*(m_poolSize + 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Clear all blocks.
|
||||
memset(&m_hashTable[0], 0, 2 * sizeof(int) * (m_poolSize + 1));
|
||||
m_secondHash = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline bool CheckAndSet(int idx)
|
||||
{
|
||||
// Inner Index is begin from 1
|
||||
return _CheckAndSet(&m_hashTable[0], idx + 1) == 0;
|
||||
}
|
||||
|
||||
|
||||
inline int _CheckAndSet(int* hashTable, int idx)
|
||||
{
|
||||
unsigned index, loop;
|
||||
|
||||
// Get first hash position.
|
||||
index = hash_func(idx);
|
||||
for (loop = 0; loop < m_maxLoop; ++loop)
|
||||
{
|
||||
if (!hashTable[index])
|
||||
{
|
||||
// index first match and record it.
|
||||
hashTable[index] = idx;
|
||||
return 1;
|
||||
}
|
||||
if (hashTable[index] == idx)
|
||||
{
|
||||
// Hit this item in hash table.
|
||||
return 0;
|
||||
}
|
||||
// Get next hash position.
|
||||
index = hash_func2(index, loop);
|
||||
}
|
||||
|
||||
if (hashTable == &m_hashTable[0])
|
||||
{
|
||||
// Use second hash block.
|
||||
m_secondHash = true;
|
||||
return _CheckAndSet(&m_hashTable[m_poolSize + 1], idx);
|
||||
}
|
||||
|
||||
// Do not include this item.
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
// Variables for each single NN search
|
||||
struct WorkSpace
|
||||
{
|
||||
void Initialize(int maxCheck, int dataSize)
|
||||
{
|
||||
nodeCheckStatus.Init(dataSize);
|
||||
m_SPTQueue.Resize(maxCheck * 10);
|
||||
m_NGQueue.Resize(maxCheck * 30);
|
||||
|
||||
m_iNumberOfTreeCheckedLeaves = 0;
|
||||
m_iNumberOfCheckedLeaves = 0;
|
||||
m_iContinuousLimit = maxCheck / 64;
|
||||
m_iMaxCheck = maxCheck;
|
||||
m_iNumOfContinuousNoBetterPropagation = 0;
|
||||
}
|
||||
|
||||
void Reset(int maxCheck)
|
||||
{
|
||||
nodeCheckStatus.clear();
|
||||
m_SPTQueue.clear();
|
||||
m_NGQueue.clear();
|
||||
|
||||
m_iNumberOfTreeCheckedLeaves = 0;
|
||||
m_iNumberOfCheckedLeaves = 0;
|
||||
m_iContinuousLimit = maxCheck / 64;
|
||||
m_iMaxCheck = maxCheck;
|
||||
m_iNumOfContinuousNoBetterPropagation = 0;
|
||||
}
|
||||
|
||||
inline bool CheckAndSet(int idx)
|
||||
{
|
||||
return nodeCheckStatus.CheckAndSet(idx);
|
||||
}
|
||||
|
||||
OptHashPosVector nodeCheckStatus;
|
||||
//OptHashPosVector nodeCheckStatus;
|
||||
|
||||
// counter for dynamic pivoting
|
||||
int m_iNumOfContinuousNoBetterPropagation;
|
||||
int m_iContinuousLimit;
|
||||
int m_iNumberOfTreeCheckedLeaves;
|
||||
int m_iNumberOfCheckedLeaves;
|
||||
int m_iMaxCheck;
|
||||
|
||||
// Prioriy queue used for neighborhood graph
|
||||
Heap<HeapCell> m_NGQueue;
|
||||
|
||||
// Priority queue Used for BKT-Tree
|
||||
Heap<HeapCell> m_SPTQueue;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_WORKSPACE_H_
|
@ -1,43 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMON_WORKSPACEPOOL_H_
|
||||
#define _SPTAG_COMMON_WORKSPACEPOOL_H_
|
||||
|
||||
#include "WorkSpace.h"
|
||||
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace COMMON
|
||||
{
|
||||
|
||||
class WorkSpacePool
|
||||
{
|
||||
public:
|
||||
WorkSpacePool(int p_maxCheck, int p_vectorCount);
|
||||
|
||||
virtual ~WorkSpacePool();
|
||||
|
||||
std::shared_ptr<WorkSpace> Rent();
|
||||
|
||||
void Return(const std::shared_ptr<WorkSpace>& p_workSpace);
|
||||
|
||||
void Init(int size);
|
||||
|
||||
private:
|
||||
std::list<std::shared_ptr<WorkSpace>> m_workSpacePool;
|
||||
|
||||
std::mutex m_workSpacePoolMutex;
|
||||
|
||||
int m_maxCheck;
|
||||
|
||||
int m_vectorCount;
|
||||
};
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif // _SPTAG_COMMON_WORKSPACEPOOL_H_
|
@ -1,56 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_COMMONDATASTRUCTURE_H_
|
||||
#define _SPTAG_COMMONDATASTRUCTURE_H_
|
||||
|
||||
#include "Common.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
class ByteArray
|
||||
{
|
||||
public:
|
||||
ByteArray();
|
||||
|
||||
ByteArray(ByteArray&& p_right);
|
||||
|
||||
ByteArray(std::uint8_t* p_array, std::size_t p_length, bool p_transferOnwership);
|
||||
|
||||
ByteArray(std::uint8_t* p_array, std::size_t p_length, std::shared_ptr<std::uint8_t> p_dataHolder);
|
||||
|
||||
ByteArray(const ByteArray& p_right);
|
||||
|
||||
ByteArray& operator= (const ByteArray& p_right);
|
||||
|
||||
ByteArray& operator= (ByteArray&& p_right);
|
||||
|
||||
~ByteArray();
|
||||
|
||||
static ByteArray Alloc(std::size_t p_length);
|
||||
|
||||
std::uint8_t* Data() const;
|
||||
|
||||
std::size_t Length() const;
|
||||
|
||||
void SetData(std::uint8_t* p_array, std::size_t p_length);
|
||||
|
||||
std::shared_ptr<std::uint8_t> DataHolder() const;
|
||||
|
||||
void Clear();
|
||||
|
||||
const static ByteArray c_empty;
|
||||
|
||||
private:
|
||||
std::uint8_t* m_data;
|
||||
|
||||
std::size_t m_length;
|
||||
|
||||
// Notice this is holding an array. Set correct deleter for this.
|
||||
std::shared_ptr<std::uint8_t> m_dataHolder;
|
||||
};
|
||||
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_COMMONDATASTRUCTURE_H_
|
@ -1,57 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef DefineVectorValueType
|
||||
|
||||
DefineVectorValueType(Int8, std::int8_t)
|
||||
DefineVectorValueType(UInt8, std::uint8_t)
|
||||
DefineVectorValueType(Int16, std::int16_t)
|
||||
DefineVectorValueType(Float, float)
|
||||
|
||||
#endif // DefineVectorValueType
|
||||
|
||||
|
||||
#ifdef DefineDistCalcMethod
|
||||
|
||||
DefineDistCalcMethod(L2)
|
||||
DefineDistCalcMethod(Cosine)
|
||||
|
||||
#endif // DefineDistCalcMethod
|
||||
|
||||
|
||||
#ifdef DefineErrorCode
|
||||
|
||||
// 0x0000 ~ 0x0FFF General Status
|
||||
DefineErrorCode(Success, 0x0000)
|
||||
DefineErrorCode(Fail, 0x0001)
|
||||
DefineErrorCode(FailedOpenFile, 0x0002)
|
||||
DefineErrorCode(FailedCreateFile, 0x0003)
|
||||
DefineErrorCode(ParamNotFound, 0x0010)
|
||||
DefineErrorCode(FailedParseValue, 0x0011)
|
||||
|
||||
// 0x1000 ~ 0x1FFF Index Build Status
|
||||
|
||||
// 0x2000 ~ 0x2FFF Index Serve Status
|
||||
|
||||
// 0x3000 ~ 0x3FFF Helper Function Status
|
||||
DefineErrorCode(ReadIni_FailedParseSection, 0x3000)
|
||||
DefineErrorCode(ReadIni_FailedParseParam, 0x3001)
|
||||
DefineErrorCode(ReadIni_DuplicatedSection, 0x3002)
|
||||
DefineErrorCode(ReadIni_DuplicatedParam, 0x3003)
|
||||
|
||||
|
||||
// 0x4000 ~ 0x4FFF Socket Library Status
|
||||
DefineErrorCode(Socket_FailedResolveEndPoint, 0x4000)
|
||||
DefineErrorCode(Socket_FailedConnectToEndPoint, 0x4001)
|
||||
|
||||
|
||||
#endif // DefineErrorCode
|
||||
|
||||
|
||||
|
||||
#ifdef DefineIndexAlgo
|
||||
|
||||
DefineIndexAlgo(BKT)
|
||||
DefineIndexAlgo(KDT)
|
||||
|
||||
#endif // DefineIndexAlgo
|
@ -1,112 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_KDT_INDEX_H_
|
||||
#define _SPTAG_KDT_INDEX_H_
|
||||
|
||||
#include "../Common.h"
|
||||
#include "../VectorIndex.h"
|
||||
|
||||
#include "../Common/CommonUtils.h"
|
||||
#include "../Common/DistanceUtils.h"
|
||||
#include "../Common/QueryResultSet.h"
|
||||
#include "../Common/Dataset.h"
|
||||
#include "../Common/WorkSpace.h"
|
||||
#include "../Common/WorkSpacePool.h"
|
||||
#include "../Common/RelativeNeighborhoodGraph.h"
|
||||
#include "../Common/KDTree.h"
|
||||
#include "inc/Helper/StringConvert.h"
|
||||
#include "inc/Helper/SimpleIniReader.h"
|
||||
|
||||
#include <functional>
|
||||
#include <mutex>
|
||||
#include <tbb/concurrent_unordered_set.h>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
namespace Helper
|
||||
{
|
||||
class IniReader;
|
||||
}
|
||||
|
||||
namespace KDT
|
||||
{
|
||||
template<typename T>
|
||||
class Index : public VectorIndex
|
||||
{
|
||||
private:
|
||||
// data points
|
||||
COMMON::Dataset<T> m_pSamples;
|
||||
|
||||
// KDT structures.
|
||||
COMMON::KDTree m_pTrees;
|
||||
|
||||
// Graph structure
|
||||
COMMON::RelativeNeighborhoodGraph m_pGraph;
|
||||
|
||||
std::string m_sKDTFilename;
|
||||
std::string m_sGraphFilename;
|
||||
std::string m_sDataPointsFilename;
|
||||
|
||||
std::mutex m_dataLock; // protect data and graph
|
||||
tbb::concurrent_unordered_set<int> m_deletedID;
|
||||
std::unique_ptr<COMMON::WorkSpacePool> m_workSpacePool;
|
||||
|
||||
int m_iNumberOfThreads;
|
||||
DistCalcMethod m_iDistCalcMethod;
|
||||
float(*m_fComputeDistance)(const T* pX, const T* pY, int length);
|
||||
|
||||
int m_iMaxCheck;
|
||||
int m_iThresholdOfNumberOfContinuousNoBetterPropagation;
|
||||
int m_iNumberOfInitialDynamicPivots;
|
||||
int m_iNumberOfOtherDynamicPivots;
|
||||
public:
|
||||
Index()
|
||||
{
|
||||
#define DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr) \
|
||||
VarName = DefaultValue; \
|
||||
|
||||
#include "inc/Core/KDT/ParameterDefinitionList.h"
|
||||
#undef DefineKDTParameter
|
||||
|
||||
m_fComputeDistance = COMMON::DistanceCalcSelector<T>(m_iDistCalcMethod);
|
||||
}
|
||||
|
||||
~Index() {}
|
||||
|
||||
inline int GetNumSamples() const { return m_pSamples.R(); }
|
||||
inline int GetFeatureDim() const { return m_pSamples.C(); }
|
||||
|
||||
inline int GetCurrMaxCheck() const { return m_iMaxCheck; }
|
||||
inline int GetNumThreads() const { return m_iNumberOfThreads; }
|
||||
inline DistCalcMethod GetDistCalcMethod() const { return m_iDistCalcMethod; }
|
||||
inline IndexAlgoType GetIndexAlgoType() const { return IndexAlgoType::KDT; }
|
||||
inline VectorValueType GetVectorValueType() const { return GetEnumValueType<T>(); }
|
||||
|
||||
inline float ComputeDistance(const void* pX, const void* pY) const { return m_fComputeDistance((const T*)pX, (const T*)pY, m_pSamples.C()); }
|
||||
inline const void* GetSample(const int idx) const { return (void*)m_pSamples[idx]; }
|
||||
|
||||
ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension);
|
||||
|
||||
ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen);
|
||||
ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs);
|
||||
|
||||
ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout);
|
||||
ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader);
|
||||
ErrorCode SearchIndex(QueryResult &p_query) const;
|
||||
ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension);
|
||||
ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum);
|
||||
|
||||
ErrorCode SetParameter(const char* p_param, const char* p_value);
|
||||
std::string GetParameter(const char* p_param) const;
|
||||
|
||||
private:
|
||||
ErrorCode RefineIndex(const std::string& p_folderPath);
|
||||
void SearchIndexWithDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space, const tbb::concurrent_unordered_set<int> &p_deleted) const;
|
||||
void SearchIndexWithoutDeleted(COMMON::QueryResultSet<T> &p_query, COMMON::WorkSpace &p_space) const;
|
||||
};
|
||||
} // namespace KDT
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_KDT_INDEX_H_
|
@ -1,34 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifdef DefineKDTParameter
|
||||
|
||||
// DefineKDTParameter(VarName, VarType, DefaultValue, RepresentStr)
|
||||
DefineKDTParameter(m_sKDTFilename, std::string, std::string("tree.bin"), "TreeFilePath")
|
||||
DefineKDTParameter(m_sGraphFilename, std::string, std::string("graph.bin"), "GraphFilePath")
|
||||
DefineKDTParameter(m_sDataPointsFilename, std::string, std::string("vectors.bin"), "VectorFilePath")
|
||||
|
||||
DefineKDTParameter(m_pTrees.m_iTreeNumber, int, 1L, "KDTNumber")
|
||||
DefineKDTParameter(m_pTrees.m_numTopDimensionKDTSplit, int, 5L, "NumTopDimensionKDTSplit")
|
||||
DefineKDTParameter(m_pTrees.m_iSamples, int, 100L, "NumSamplesKDTSplitConsideration")
|
||||
|
||||
DefineKDTParameter(m_pGraph.m_iTPTNumber, int, 32L, "TPTNumber")
|
||||
DefineKDTParameter(m_pGraph.m_iTPTLeafSize, int, 2000L, "TPTLeafSize")
|
||||
DefineKDTParameter(m_pGraph.m_numTopDimensionTPTSplit, int, 5L, "NumTopDimensionTPTSplit")
|
||||
|
||||
DefineKDTParameter(m_pGraph.m_iNeighborhoodSize, int, 32L, "NeighborhoodSize")
|
||||
DefineKDTParameter(m_pGraph.m_iNeighborhoodScale, int, 2L, "GraphNeighborhoodScale")
|
||||
DefineKDTParameter(m_pGraph.m_iCEFScale, int, 2L, "GraphCEFScale")
|
||||
DefineKDTParameter(m_pGraph.m_iRefineIter, int, 0L, "RefineIterations")
|
||||
DefineKDTParameter(m_pGraph.m_iCEF, int, 1000L, "CEF")
|
||||
DefineKDTParameter(m_pGraph.m_iMaxCheckForRefineGraph, int, 10000L, "MaxCheckForRefineGraph")
|
||||
|
||||
DefineKDTParameter(m_iNumberOfThreads, int, 1L, "NumberOfThreads")
|
||||
DefineKDTParameter(m_iDistCalcMethod, SPTAG::DistCalcMethod, SPTAG::DistCalcMethod::Cosine, "DistCalcMethod")
|
||||
|
||||
DefineKDTParameter(m_iMaxCheck, int, 8192L, "MaxCheck")
|
||||
DefineKDTParameter(m_iThresholdOfNumberOfContinuousNoBetterPropagation, int, 3L, "ThresholdOfNumberOfContinuousNoBetterPropagation")
|
||||
DefineKDTParameter(m_iNumberOfInitialDynamicPivots, int, 50L, "NumberOfInitialDynamicPivots")
|
||||
DefineKDTParameter(m_iNumberOfOtherDynamicPivots, int, 4L, "NumberOfOtherDynamicPivots")
|
||||
|
||||
#endif
|
@ -1,114 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_METADATASET_H_
|
||||
#define _SPTAG_METADATASET_H_
|
||||
|
||||
#include "CommonDataStructure.h"
|
||||
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
class MetadataSet
|
||||
{
|
||||
public:
|
||||
MetadataSet();
|
||||
|
||||
virtual ~MetadataSet();
|
||||
|
||||
virtual ByteArray GetMetadata(IndexType p_vectorID) const = 0;
|
||||
|
||||
virtual SizeType Count() const = 0;
|
||||
|
||||
virtual bool Available() const = 0;
|
||||
|
||||
virtual void AddBatch(MetadataSet& data) = 0;
|
||||
|
||||
virtual ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile) = 0;
|
||||
|
||||
virtual ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len) = 0;
|
||||
|
||||
virtual ErrorCode LoadMetadataFromMemory(void *pGraphMemFile) = 0;
|
||||
|
||||
virtual ErrorCode RefineMetadata(std::vector<int>& indices, const std::string& p_folderPath);
|
||||
|
||||
static ErrorCode MetaCopy(const std::string& p_src, const std::string& p_dst);
|
||||
};
|
||||
|
||||
|
||||
class FileMetadataSet : public MetadataSet
|
||||
{
|
||||
public:
|
||||
FileMetadataSet(const std::string& p_metaFile, const std::string& p_metaindexFile);
|
||||
|
||||
~FileMetadataSet();
|
||||
|
||||
ByteArray GetMetadata(IndexType p_vectorID) const;
|
||||
|
||||
SizeType Count() const;
|
||||
|
||||
bool Available() const;
|
||||
|
||||
void AddBatch(MetadataSet& data);
|
||||
|
||||
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
|
||||
|
||||
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
|
||||
|
||||
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
|
||||
private:
|
||||
std::ifstream* m_fp = nullptr;
|
||||
|
||||
std::vector<std::uint64_t> m_pOffsets;
|
||||
|
||||
SizeType m_count;
|
||||
|
||||
std::string m_metaFile;
|
||||
|
||||
std::string m_metaindexFile;
|
||||
|
||||
std::vector<std::uint8_t> m_newdata;
|
||||
};
|
||||
|
||||
|
||||
class MemMetadataSet : public MetadataSet
|
||||
{
|
||||
public:
|
||||
MemMetadataSet() = default;
|
||||
|
||||
MemMetadataSet(ByteArray p_metadata, ByteArray p_offsets, SizeType p_count);
|
||||
|
||||
~MemMetadataSet();
|
||||
|
||||
ByteArray GetMetadata(IndexType p_vectorID) const;
|
||||
|
||||
SizeType Count() const;
|
||||
|
||||
bool Available() const;
|
||||
|
||||
void AddBatch(MetadataSet& data);
|
||||
|
||||
ErrorCode SaveMetadata(const std::string& p_metaFile, const std::string& p_metaindexFile);
|
||||
|
||||
ErrorCode SaveMetadataToMemory(void **pGraphMemFile, int64_t &len);
|
||||
|
||||
ErrorCode LoadMetadataFromMemory(void *pGraphMemFile);
|
||||
private:
|
||||
std::vector<std::uint64_t> m_offsets;
|
||||
|
||||
SizeType m_count;
|
||||
|
||||
ByteArray m_metadataHolder;
|
||||
|
||||
ByteArray m_offsetHolder;
|
||||
|
||||
std::vector<std::uint8_t> m_newdata;
|
||||
};
|
||||
|
||||
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_METADATASET_H_
|
@ -1,239 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SEARCHQUERY_H_
|
||||
#define _SPTAG_SEARCHQUERY_H_
|
||||
|
||||
#include "CommonDataStructure.h"
|
||||
|
||||
#include <cstring>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
struct BasicResult
|
||||
{
|
||||
int VID;
|
||||
float Dist;
|
||||
|
||||
BasicResult() : VID(-1), Dist(MaxDist) {}
|
||||
|
||||
BasicResult(int p_vid, float p_dist) : VID(p_vid), Dist(p_dist) {}
|
||||
};
|
||||
|
||||
|
||||
// Space to save temporary answer, similar with TopKCache
|
||||
class QueryResult
|
||||
{
|
||||
public:
|
||||
typedef BasicResult* iterator;
|
||||
typedef const BasicResult* const_iterator;
|
||||
|
||||
QueryResult()
|
||||
: m_target(nullptr),
|
||||
m_resultNum(0),
|
||||
m_withMeta(false)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
QueryResult(const void* p_target, int p_resultNum, bool p_withMeta)
|
||||
: m_target(nullptr),
|
||||
m_resultNum(0),
|
||||
m_withMeta(false)
|
||||
{
|
||||
Init(p_target, p_resultNum, p_withMeta);
|
||||
}
|
||||
|
||||
|
||||
QueryResult(const void* p_target, int p_resultNum, std::vector<BasicResult>& p_results)
|
||||
: m_target(p_target),
|
||||
m_resultNum(p_resultNum),
|
||||
m_withMeta(false)
|
||||
{
|
||||
p_results.resize(p_resultNum);
|
||||
m_results.reset(p_results.data());
|
||||
}
|
||||
|
||||
|
||||
QueryResult(const QueryResult& p_other)
|
||||
: m_target(p_other.m_target),
|
||||
m_resultNum(p_other.m_resultNum),
|
||||
m_withMeta(p_other.m_withMeta)
|
||||
{
|
||||
if (m_resultNum > 0)
|
||||
{
|
||||
m_results.reset(new BasicResult[m_resultNum]);
|
||||
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
|
||||
|
||||
if (m_withMeta)
|
||||
{
|
||||
m_metadatas.reset(new ByteArray[m_resultNum]);
|
||||
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
QueryResult& operator=(const QueryResult& p_other)
|
||||
{
|
||||
Init(p_other.m_target, p_other.m_resultNum, p_other.m_withMeta);
|
||||
|
||||
if (m_resultNum > 0)
|
||||
{
|
||||
std::memcpy(m_results.get(), p_other.m_results.get(), sizeof(BasicResult) * m_resultNum);
|
||||
if (m_withMeta)
|
||||
{
|
||||
std::copy(p_other.m_metadatas.get(), p_other.m_metadatas.get() + m_resultNum, m_metadatas.get());
|
||||
}
|
||||
}
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
~QueryResult()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
inline void Init(const void* p_target, int p_resultNum, bool p_withMeta)
|
||||
{
|
||||
m_target = p_target;
|
||||
if (p_resultNum > m_resultNum)
|
||||
{
|
||||
m_results.reset(new BasicResult[p_resultNum]);
|
||||
}
|
||||
|
||||
if (p_withMeta && (!m_withMeta || p_resultNum > m_resultNum))
|
||||
{
|
||||
m_metadatas.reset(new ByteArray[p_resultNum]);
|
||||
}
|
||||
|
||||
m_resultNum = p_resultNum;
|
||||
m_withMeta = p_withMeta;
|
||||
}
|
||||
|
||||
|
||||
inline int GetResultNum() const
|
||||
{
|
||||
return m_resultNum;
|
||||
}
|
||||
|
||||
|
||||
inline const void* GetTarget()
|
||||
{
|
||||
return m_target;
|
||||
}
|
||||
|
||||
|
||||
inline void SetTarget(const void* p_target)
|
||||
{
|
||||
m_target = p_target;
|
||||
}
|
||||
|
||||
|
||||
inline BasicResult* GetResult(int i) const
|
||||
{
|
||||
return i < m_resultNum ? m_results.get() + i : nullptr;
|
||||
}
|
||||
|
||||
|
||||
inline void SetResult(int p_index, int p_VID, float p_dist)
|
||||
{
|
||||
if (p_index < m_resultNum)
|
||||
{
|
||||
m_results[p_index].VID = p_VID;
|
||||
m_results[p_index].Dist = p_dist;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline BasicResult* GetResults() const
|
||||
{
|
||||
return m_results.get();
|
||||
}
|
||||
|
||||
|
||||
inline bool WithMeta() const
|
||||
{
|
||||
return m_withMeta;
|
||||
}
|
||||
|
||||
|
||||
inline const ByteArray& GetMetadata(int p_index) const
|
||||
{
|
||||
if (p_index < m_resultNum && m_withMeta)
|
||||
{
|
||||
return m_metadatas[p_index];
|
||||
}
|
||||
|
||||
return ByteArray::c_empty;
|
||||
}
|
||||
|
||||
|
||||
inline void SetMetadata(int p_index, ByteArray p_metadata)
|
||||
{
|
||||
if (p_index < m_resultNum && m_withMeta)
|
||||
{
|
||||
m_metadatas[p_index] = std::move(p_metadata);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
inline void Reset()
|
||||
{
|
||||
for (int i = 0; i < m_resultNum; i++)
|
||||
{
|
||||
m_results[i].VID = -1;
|
||||
m_results[i].Dist = MaxDist;
|
||||
}
|
||||
|
||||
if (m_withMeta)
|
||||
{
|
||||
for (int i = 0; i < m_resultNum; i++)
|
||||
{
|
||||
m_metadatas[i].Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
iterator begin()
|
||||
{
|
||||
return m_results.get();
|
||||
}
|
||||
|
||||
|
||||
iterator end()
|
||||
{
|
||||
return m_results.get() + m_resultNum;
|
||||
}
|
||||
|
||||
|
||||
const_iterator begin() const
|
||||
{
|
||||
return m_results.get();
|
||||
}
|
||||
|
||||
|
||||
const_iterator end() const
|
||||
{
|
||||
return m_results.get() + m_resultNum;
|
||||
}
|
||||
|
||||
|
||||
protected:
|
||||
const void* m_target;
|
||||
|
||||
int m_resultNum;
|
||||
|
||||
bool m_withMeta;
|
||||
|
||||
std::unique_ptr<BasicResult[]> m_results;
|
||||
|
||||
std::unique_ptr<ByteArray[]> m_metadatas;
|
||||
};
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SEARCHQUERY_H_
|
@ -1,94 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_VECTORINDEX_H_
|
||||
#define _SPTAG_VECTORINDEX_H_
|
||||
|
||||
#include "Common.h"
|
||||
#include "SearchQuery.h"
|
||||
#include "VectorSet.h"
|
||||
#include "MetadataSet.h"
|
||||
#include "inc/Helper/SimpleIniReader.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
class VectorIndex
|
||||
{
|
||||
public:
|
||||
VectorIndex();
|
||||
|
||||
virtual ~VectorIndex();
|
||||
|
||||
virtual ErrorCode SaveIndex(const std::string& p_folderPath, std::ofstream& p_configout) = 0;
|
||||
|
||||
virtual ErrorCode LoadIndex(const std::string& p_folderPath, Helper::IniReader& p_reader) = 0;
|
||||
|
||||
virtual ErrorCode SaveIndexToMemory(std::vector<void*>& p_indexBlobs, std::vector<int64_t>& p_indexBlobsLen) = 0;
|
||||
|
||||
virtual ErrorCode LoadIndexFromMemory(const std::vector<void*>& p_indexBlobs) = 0;
|
||||
|
||||
virtual ErrorCode BuildIndex(const void* p_data, int p_vectorNum, int p_dimension) = 0;
|
||||
|
||||
virtual ErrorCode SearchIndex(QueryResult& p_results) const = 0;
|
||||
|
||||
virtual ErrorCode AddIndex(const void* p_vectors, int p_vectorNum, int p_dimension) = 0;
|
||||
|
||||
virtual ErrorCode DeleteIndex(const void* p_vectors, int p_vectorNum) = 0;
|
||||
|
||||
//virtual ErrorCode AddIndexWithID(const void* p_vector, const int& p_id) = 0;
|
||||
|
||||
//virtual ErrorCode DeleteIndexWithID(const void* p_vector, const int& p_id) = 0;
|
||||
|
||||
virtual float ComputeDistance(const void* pX, const void* pY) const = 0;
|
||||
virtual const void* GetSample(const int idx) const = 0;
|
||||
virtual int GetFeatureDim() const = 0;
|
||||
virtual int GetNumSamples() const = 0;
|
||||
|
||||
virtual DistCalcMethod GetDistCalcMethod() const = 0;
|
||||
virtual IndexAlgoType GetIndexAlgoType() const = 0;
|
||||
virtual VectorValueType GetVectorValueType() const = 0;
|
||||
virtual int GetNumThreads() const = 0;
|
||||
|
||||
virtual std::string GetParameter(const char* p_param) const = 0;
|
||||
virtual ErrorCode SetParameter(const char* p_param, const char* p_value) = 0;
|
||||
|
||||
virtual ErrorCode LoadIndex(const std::string& p_folderPath);
|
||||
|
||||
virtual ErrorCode SaveIndex(const std::string& p_folderPath);
|
||||
|
||||
virtual ErrorCode BuildIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
|
||||
|
||||
virtual ErrorCode SearchIndex(const void* p_vector, int p_neighborCount, std::vector<BasicResult>& p_results) const;
|
||||
|
||||
virtual ErrorCode AddIndex(std::shared_ptr<VectorSet> p_vectorSet, std::shared_ptr<MetadataSet> p_metadataSet);
|
||||
|
||||
virtual std::string GetParameter(const std::string& p_param) const;
|
||||
virtual ErrorCode SetParameter(const std::string& p_param, const std::string& p_value);
|
||||
|
||||
virtual ByteArray GetMetadata(IndexType p_vectorID) const;
|
||||
virtual void SetMetadata(const std::string& p_metadataFilePath, const std::string& p_metadataIndexPath);
|
||||
|
||||
virtual std::string GetIndexName() const
|
||||
{
|
||||
if (m_sIndexName == "")
|
||||
return Helper::Convert::ConvertToString(GetIndexAlgoType());
|
||||
return m_sIndexName;
|
||||
}
|
||||
virtual void SetIndexName(std::string p_name) { m_sIndexName = p_name; }
|
||||
|
||||
static std::shared_ptr<VectorIndex> CreateInstance(IndexAlgoType p_algo, VectorValueType p_valuetype);
|
||||
|
||||
static ErrorCode MergeIndex(const char* p_indexFilePath1, const char* p_indexFilePath2);
|
||||
|
||||
static ErrorCode LoadIndex(const std::string& p_loaderFilePath, std::shared_ptr<VectorIndex>& p_vectorIndex);
|
||||
|
||||
protected:
|
||||
std::string m_sIndexName;
|
||||
std::shared_ptr<MetadataSet> m_pMetadata;
|
||||
};
|
||||
|
||||
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_VECTORINDEX_H_
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_VECTORSET_H_
|
||||
#define _SPTAG_VECTORSET_H_
|
||||
|
||||
#include "CommonDataStructure.h"
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
|
||||
class VectorSet
|
||||
{
|
||||
public:
|
||||
VectorSet();
|
||||
|
||||
virtual ~VectorSet();
|
||||
|
||||
virtual VectorValueType GetValueType() const = 0;
|
||||
|
||||
virtual void* GetVector(IndexType p_vectorID) const = 0;
|
||||
|
||||
virtual void* GetData() const = 0;
|
||||
|
||||
virtual SizeType Dimension() const = 0;
|
||||
|
||||
virtual SizeType Count() const = 0;
|
||||
|
||||
virtual bool Available() const = 0;
|
||||
|
||||
virtual ErrorCode Save(const std::string& p_vectorFile) const = 0;
|
||||
};
|
||||
|
||||
|
||||
class BasicVectorSet : public VectorSet
|
||||
{
|
||||
public:
|
||||
BasicVectorSet(const ByteArray& p_bytesArray,
|
||||
VectorValueType p_valueType,
|
||||
SizeType p_dimension,
|
||||
SizeType p_vectorCount);
|
||||
|
||||
virtual ~BasicVectorSet();
|
||||
|
||||
virtual VectorValueType GetValueType() const;
|
||||
|
||||
virtual void* GetVector(IndexType p_vectorID) const;
|
||||
|
||||
virtual void* GetData() const;
|
||||
|
||||
virtual SizeType Dimension() const;
|
||||
|
||||
virtual SizeType Count() const;
|
||||
|
||||
virtual bool Available() const;
|
||||
|
||||
virtual ErrorCode Save(const std::string& p_vectorFile) const;
|
||||
|
||||
private:
|
||||
ByteArray m_data;
|
||||
|
||||
VectorValueType m_valueType;
|
||||
|
||||
SizeType m_dimension;
|
||||
|
||||
SizeType m_vectorCount;
|
||||
|
||||
SizeType m_perVectorDataSize;
|
||||
};
|
||||
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_VECTORSET_H_
|
@ -1,253 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_ARGUMENTSPARSER_H_
|
||||
#define _SPTAG_HELPER_ARGUMENTSPARSER_H_
|
||||
|
||||
#include "inc/Helper/StringConvert.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
|
||||
class ArgumentsParser
|
||||
{
|
||||
public:
|
||||
ArgumentsParser();
|
||||
|
||||
virtual ~ArgumentsParser();
|
||||
|
||||
virtual bool Parse(int p_argc, char** p_args);
|
||||
|
||||
virtual void PrintHelp();
|
||||
|
||||
protected:
|
||||
class IArgument
|
||||
{
|
||||
public:
|
||||
IArgument();
|
||||
|
||||
virtual ~IArgument();
|
||||
|
||||
virtual bool ParseValue(int& p_restArgc, char** (&p_args)) = 0;
|
||||
|
||||
virtual void PrintDescription(FILE* p_output) = 0;
|
||||
|
||||
virtual bool IsRequiredButNotSet() const = 0;
|
||||
};
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
class ArgumentT : public IArgument
|
||||
{
|
||||
public:
|
||||
ArgumentT(DataType& p_target,
|
||||
const std::string& p_representStringShort,
|
||||
const std::string& p_representString,
|
||||
const std::string& p_description,
|
||||
bool p_followedValue,
|
||||
const DataType& p_switchAsValue,
|
||||
bool p_isRequired)
|
||||
: m_value(p_target),
|
||||
m_representStringShort(p_representStringShort),
|
||||
m_representString(p_representString),
|
||||
m_description(p_description),
|
||||
m_followedValue(p_followedValue),
|
||||
c_switchAsValue(p_switchAsValue),
|
||||
m_isRequired(p_isRequired),
|
||||
m_isSet(false)
|
||||
{
|
||||
}
|
||||
|
||||
virtual ~ArgumentT()
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
virtual bool ParseValue(int& p_restArgc, char** (&p_args))
|
||||
{
|
||||
if (0 == p_restArgc)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (0 != strcmp(*p_args, m_representString.c_str())
|
||||
&& 0 != strcmp(*p_args, m_representStringShort.c_str()))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (!m_followedValue)
|
||||
{
|
||||
m_value = c_switchAsValue;
|
||||
--p_restArgc;
|
||||
++p_args;
|
||||
m_isSet = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (p_restArgc < 2)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
DataType tmp;
|
||||
if (!Helper::Convert::ConvertStringTo(p_args[1], tmp))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
m_value = std::move(tmp);
|
||||
|
||||
p_restArgc -= 2;
|
||||
p_args += 2;
|
||||
m_isSet = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
virtual void PrintDescription(FILE* p_output)
|
||||
{
|
||||
std::size_t padding = 30;
|
||||
if (!m_representStringShort.empty())
|
||||
{
|
||||
fprintf(p_output, "%s", m_representStringShort.c_str());
|
||||
padding -= m_representStringShort.size();
|
||||
}
|
||||
|
||||
if (!m_representString.empty())
|
||||
{
|
||||
if (!m_representStringShort.empty())
|
||||
{
|
||||
fprintf(p_output, ", ");
|
||||
padding -= 2;
|
||||
}
|
||||
|
||||
fprintf(p_output, "%s", m_representString.c_str());
|
||||
padding -= m_representString.size();
|
||||
}
|
||||
|
||||
if (m_followedValue)
|
||||
{
|
||||
fprintf(p_output, " <value>");
|
||||
padding -= 8;
|
||||
}
|
||||
|
||||
while (padding-- > 0)
|
||||
{
|
||||
fputc(' ', p_output);
|
||||
}
|
||||
|
||||
fprintf(p_output, "%s", m_description.c_str());
|
||||
}
|
||||
|
||||
|
||||
virtual bool IsRequiredButNotSet() const
|
||||
{
|
||||
return m_isRequired && !m_isSet;
|
||||
}
|
||||
|
||||
private:
|
||||
DataType & m_value;
|
||||
|
||||
std::string m_representStringShort;
|
||||
|
||||
std::string m_representString;
|
||||
|
||||
std::string m_description;
|
||||
|
||||
bool m_followedValue;
|
||||
|
||||
const DataType c_switchAsValue;
|
||||
|
||||
bool m_isRequired;
|
||||
|
||||
bool m_isSet;
|
||||
};
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
void AddRequiredOption(DataType& p_target,
|
||||
const std::string& p_representStringShort,
|
||||
const std::string& p_representString,
|
||||
const std::string& p_description)
|
||||
{
|
||||
m_arguments.emplace_back(std::shared_ptr<IArgument>(
|
||||
new ArgumentT<DataType>(p_target,
|
||||
p_representStringShort,
|
||||
p_representString,
|
||||
p_description,
|
||||
true,
|
||||
DataType(),
|
||||
true)));
|
||||
}
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
void AddOptionalOption(DataType& p_target,
|
||||
const std::string& p_representStringShort,
|
||||
const std::string& p_representString,
|
||||
const std::string& p_description)
|
||||
{
|
||||
m_arguments.emplace_back(std::shared_ptr<IArgument>(
|
||||
new ArgumentT<DataType>(p_target,
|
||||
p_representStringShort,
|
||||
p_representString,
|
||||
p_description,
|
||||
true,
|
||||
DataType(),
|
||||
false)));
|
||||
}
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
void AddRequiredSwitch(DataType& p_target,
|
||||
const std::string& p_representStringShort,
|
||||
const std::string& p_representString,
|
||||
const std::string& p_description,
|
||||
const DataType& p_switchAsValue)
|
||||
{
|
||||
m_arguments.emplace_back(std::shared_ptr<IArgument>(
|
||||
new ArgumentT<DataType>(p_target,
|
||||
p_representStringShort,
|
||||
p_representString,
|
||||
p_description,
|
||||
false,
|
||||
p_switchAsValue,
|
||||
true)));
|
||||
}
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
void AddOptionalSwitch(DataType& p_target,
|
||||
const std::string& p_representStringShort,
|
||||
const std::string& p_representString,
|
||||
const std::string& p_description,
|
||||
const DataType& p_switchAsValue)
|
||||
{
|
||||
m_arguments.emplace_back(std::shared_ptr<IArgument>(
|
||||
new ArgumentT<DataType>(p_target,
|
||||
p_representStringShort,
|
||||
p_representString,
|
||||
p_description,
|
||||
false,
|
||||
p_switchAsValue,
|
||||
false)));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<IArgument>> m_arguments;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_ARGUMENTSPARSER_H_
|
@ -1,33 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_BASE64ENCODE_H_
|
||||
#define _SPTAG_HELPER_BASE64ENCODE_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstddef>
|
||||
#include <ostream>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
namespace Base64
|
||||
{
|
||||
|
||||
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, char* p_out, std::size_t& p_outLen);
|
||||
|
||||
bool Encode(const std::uint8_t* p_in, std::size_t p_inLen, std::ostream& p_out, std::size_t& p_outLen);
|
||||
|
||||
bool Decode(const char* p_in, std::size_t p_inLen, std::uint8_t* p_out, std::size_t& p_outLen);
|
||||
|
||||
std::size_t CapacityForEncode(std::size_t p_inLen);
|
||||
|
||||
std::size_t CapacityForDecode(std::size_t p_inLen);
|
||||
|
||||
|
||||
} // namespace Base64
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_BASE64ENCODE_H_
|
@ -1,40 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_COMMONHELPER_H_
|
||||
#define _SPTAG_HELPER_COMMONHELPER_H_
|
||||
|
||||
#include "../Core/Common.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cctype>
|
||||
#include <functional>
|
||||
#include <limits>
|
||||
#include <cerrno>
|
||||
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
namespace StrUtils
|
||||
{
|
||||
|
||||
void ToLowerInPlace(std::string& p_str);
|
||||
|
||||
std::vector<std::string> SplitString(const std::string& p_str, const std::string& p_separator);
|
||||
|
||||
std::pair<const char*, const char*> FindTrimmedSegment(const char* p_begin,
|
||||
const char* p_end,
|
||||
const std::function<bool(char)>& p_isSkippedChar);
|
||||
|
||||
bool StartsWith(const char* p_str, const char* p_prefix);
|
||||
|
||||
bool StrEqualIgnoreCase(const char* p_left, const char* p_right);
|
||||
|
||||
} // namespace StrUtils
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_COMMONHELPER_H_
|
@ -1,97 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_CONCURRENT_H_
|
||||
#define _SPTAG_HELPER_CONCURRENT_H_
|
||||
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
namespace Concurrent
|
||||
{
|
||||
|
||||
class SpinLock
|
||||
{
|
||||
public:
|
||||
SpinLock() = default;
|
||||
|
||||
void Lock() noexcept
|
||||
{
|
||||
while (m_lock.test_and_set(std::memory_order_acquire))
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
void Unlock() noexcept
|
||||
{
|
||||
m_lock.clear(std::memory_order_release);
|
||||
}
|
||||
|
||||
SpinLock(const SpinLock&) = delete;
|
||||
SpinLock& operator = (const SpinLock&) = delete;
|
||||
|
||||
private:
|
||||
std::atomic_flag m_lock = ATOMIC_FLAG_INIT;
|
||||
};
|
||||
|
||||
template<typename Lock>
|
||||
class LockGuard {
|
||||
public:
|
||||
LockGuard(Lock& lock) noexcept
|
||||
: m_lock(lock) {
|
||||
lock.Lock();
|
||||
}
|
||||
|
||||
LockGuard(Lock& lock, std::adopt_lock_t) noexcept
|
||||
: m_lock(lock) {}
|
||||
|
||||
~LockGuard() {
|
||||
m_lock.Unlock();
|
||||
}
|
||||
|
||||
LockGuard(const LockGuard&) = delete;
|
||||
LockGuard& operator=(const LockGuard&) = delete;
|
||||
|
||||
private:
|
||||
Lock& m_lock;
|
||||
};
|
||||
|
||||
|
||||
class WaitSignal
|
||||
{
|
||||
public:
|
||||
WaitSignal();
|
||||
|
||||
WaitSignal(std::uint32_t p_unfinished);
|
||||
|
||||
~WaitSignal();
|
||||
|
||||
void Reset(std::uint32_t p_unfinished);
|
||||
|
||||
void Wait();
|
||||
|
||||
void FinishOne();
|
||||
|
||||
private:
|
||||
std::atomic<std::uint32_t> m_unfinished;
|
||||
|
||||
std::atomic_bool m_isWaiting;
|
||||
|
||||
std::mutex m_mutex;
|
||||
|
||||
std::condition_variable m_cv;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Base64
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_CONCURRENT_H_
|
@ -1,97 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_INIREADER_H_
|
||||
#define _SPTAG_HELPER_INIREADER_H_
|
||||
|
||||
#include "../Core/Common.h"
|
||||
#include "StringConvert.h"
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
|
||||
// Simple INI Reader with basic functions. Case insensitive.
|
||||
class IniReader
|
||||
{
|
||||
public:
|
||||
typedef std::map<std::string, std::string> ParameterValueMap;
|
||||
|
||||
IniReader();
|
||||
|
||||
~IniReader();
|
||||
|
||||
ErrorCode LoadIniFile(const std::string& p_iniFilePath);
|
||||
|
||||
bool DoesSectionExist(const std::string& p_section) const;
|
||||
|
||||
bool DoesParameterExist(const std::string& p_section, const std::string& p_param) const;
|
||||
|
||||
const ParameterValueMap& GetParameters(const std::string& p_section) const;
|
||||
|
||||
template <typename DataType>
|
||||
DataType GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const;
|
||||
|
||||
void SetParameter(const std::string& p_section, const std::string& p_param, const std::string& p_val);
|
||||
|
||||
private:
|
||||
bool GetRawValue(const std::string& p_section, const std::string& p_param, std::string& p_value) const;
|
||||
|
||||
template <typename DataType>
|
||||
static inline DataType ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal);
|
||||
|
||||
private:
|
||||
const static ParameterValueMap c_emptyParameters;
|
||||
|
||||
std::map<std::string, std::shared_ptr<ParameterValueMap>> m_parameters;
|
||||
};
|
||||
|
||||
|
||||
template <typename DataType>
|
||||
DataType
|
||||
IniReader::GetParameter(const std::string& p_section, const std::string& p_param, const DataType& p_defaultVal) const
|
||||
{
|
||||
std::string value;
|
||||
if (!GetRawValue(p_section, p_param, value))
|
||||
{
|
||||
return p_defaultVal;
|
||||
}
|
||||
|
||||
return ConvertStringTo<DataType>(std::move(value), p_defaultVal);
|
||||
}
|
||||
|
||||
|
||||
template <typename DataType>
|
||||
inline DataType
|
||||
IniReader::ConvertStringTo(std::string&& p_str, const DataType& p_defaultVal)
|
||||
{
|
||||
DataType value;
|
||||
if (Convert::ConvertStringTo<DataType>(p_str.c_str(), value))
|
||||
{
|
||||
return value;
|
||||
}
|
||||
|
||||
return p_defaultVal;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline std::string
|
||||
IniReader::ConvertStringTo<std::string>(std::string&& p_str, const std::string& p_defaultVal)
|
||||
{
|
||||
return std::move(p_str);
|
||||
}
|
||||
|
||||
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_INIREADER_H_
|
@ -1,374 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_HELPER_STRINGCONVERTHELPER_H_
|
||||
#define _SPTAG_HELPER_STRINGCONVERTHELPER_H_
|
||||
|
||||
#include "inc/Core/Common.h"
|
||||
#include "CommonHelper.h"
|
||||
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
#include <sstream>
|
||||
#include <cctype>
|
||||
#include <limits>
|
||||
#include <cerrno>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Helper
|
||||
{
|
||||
namespace Convert
|
||||
{
|
||||
|
||||
template <typename DataType>
|
||||
inline bool ConvertStringTo(const char* p_str, DataType& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
std::istringstream sstream;
|
||||
sstream.str(p_str);
|
||||
if (p_str >> p_value)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
template<typename DataType>
|
||||
inline std::string ConvertToString(const DataType& p_value)
|
||||
{
|
||||
return std::to_string(p_value);
|
||||
}
|
||||
|
||||
|
||||
// Specialization of ConvertStringTo<>().
|
||||
|
||||
template <typename DataType>
|
||||
inline bool ConvertStringToSignedInt(const char* p_str, DataType& p_value)
|
||||
{
|
||||
static_assert(std::is_integral<DataType>::value && std::is_signed<DataType>::value, "type check");
|
||||
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
char* end = nullptr;
|
||||
errno = 0;
|
||||
auto val = std::strtoll(p_str, &end, 10);
|
||||
if (errno == ERANGE || end == p_str || *end != '\0')
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (val < (std::numeric_limits<DataType>::min)() || val >(std::numeric_limits<DataType>::max)())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
p_value = static_cast<DataType>(val);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <typename DataType>
|
||||
inline bool ConvertStringToUnsignedInt(const char* p_str, DataType& p_value)
|
||||
{
|
||||
static_assert(std::is_integral<DataType>::value && std::is_unsigned<DataType>::value, "type check");
|
||||
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
char* end = nullptr;
|
||||
errno = 0;
|
||||
auto val = std::strtoull(p_str, &end, 10);
|
||||
if (errno == ERANGE || end == p_str || *end != '\0')
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (val < (std::numeric_limits<DataType>::min)() || val >(std::numeric_limits<DataType>::max)())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
p_value = static_cast<DataType>(val);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::string>(const char* p_str, std::string& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
p_value = p_str;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<float>(const char* p_str, float& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
char* end = nullptr;
|
||||
errno = 0;
|
||||
p_value = std::strtof(p_str, &end);
|
||||
return (errno != ERANGE && end != p_str && *end == '\0');
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<double>(const char* p_str, double& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
char* end = nullptr;
|
||||
errno = 0;
|
||||
p_value = std::strtod(p_str, &end);
|
||||
return (errno != ERANGE && end != p_str && *end == '\0');
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::int8_t>(const char* p_str, std::int8_t& p_value)
|
||||
{
|
||||
return ConvertStringToSignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::int16_t>(const char* p_str, std::int16_t& p_value)
|
||||
{
|
||||
return ConvertStringToSignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::int32_t>(const char* p_str, std::int32_t& p_value)
|
||||
{
|
||||
return ConvertStringToSignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::int64_t>(const char* p_str, std::int64_t& p_value)
|
||||
{
|
||||
return ConvertStringToSignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::uint8_t>(const char* p_str, std::uint8_t& p_value)
|
||||
{
|
||||
return ConvertStringToUnsignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::uint16_t>(const char* p_str, std::uint16_t& p_value)
|
||||
{
|
||||
return ConvertStringToUnsignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::uint32_t>(const char* p_str, std::uint32_t& p_value)
|
||||
{
|
||||
return ConvertStringToUnsignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<std::uint64_t>(const char* p_str, std::uint64_t& p_value)
|
||||
{
|
||||
return ConvertStringToUnsignedInt(p_str, p_value);
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<bool>(const char* p_str, bool& p_value)
|
||||
{
|
||||
if (StrUtils::StrEqualIgnoreCase(p_str, "true"))
|
||||
{
|
||||
p_value = true;
|
||||
|
||||
}
|
||||
else if (StrUtils::StrEqualIgnoreCase(p_str, "false"))
|
||||
{
|
||||
p_value = false;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<IndexAlgoType>(const char* p_str, IndexAlgoType& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#define DefineIndexAlgo(Name) \
|
||||
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
|
||||
{ \
|
||||
p_value = IndexAlgoType::Name; \
|
||||
return true; \
|
||||
} \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineIndexAlgo
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<DistCalcMethod>(const char* p_str, DistCalcMethod& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#define DefineDistCalcMethod(Name) \
|
||||
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
|
||||
{ \
|
||||
p_value = DistCalcMethod::Name; \
|
||||
return true; \
|
||||
} \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineDistCalcMethod
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline bool ConvertStringTo<VectorValueType>(const char* p_str, VectorValueType& p_value)
|
||||
{
|
||||
if (nullptr == p_str)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
#define DefineVectorValueType(Name, Type) \
|
||||
else if (StrUtils::StrEqualIgnoreCase(p_str, #Name)) \
|
||||
{ \
|
||||
p_value = VectorValueType::Name; \
|
||||
return true; \
|
||||
} \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineVectorValueType
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
||||
// Specialization of ConvertToString<>().
|
||||
|
||||
template<>
|
||||
inline std::string ConvertToString<std::string>(const std::string& p_value)
|
||||
{
|
||||
return p_value;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline std::string ConvertToString<bool>(const bool& p_value)
|
||||
{
|
||||
return p_value ? "true" : "false";
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline std::string ConvertToString<IndexAlgoType>(const IndexAlgoType& p_value)
|
||||
{
|
||||
switch (p_value)
|
||||
{
|
||||
#define DefineIndexAlgo(Name) \
|
||||
case IndexAlgoType::Name: \
|
||||
return #Name; \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineIndexAlgo
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return "Undefined";
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline std::string ConvertToString<DistCalcMethod>(const DistCalcMethod& p_value)
|
||||
{
|
||||
switch (p_value)
|
||||
{
|
||||
#define DefineDistCalcMethod(Name) \
|
||||
case DistCalcMethod::Name: \
|
||||
return #Name; \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineDistCalcMethod
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return "Undefined";
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
inline std::string ConvertToString<VectorValueType>(const VectorValueType& p_value)
|
||||
{
|
||||
switch (p_value)
|
||||
{
|
||||
#define DefineVectorValueType(Name, Type) \
|
||||
case VectorValueType::Name: \
|
||||
return #Name; \
|
||||
|
||||
#include "inc/Core/DefinitionList.h"
|
||||
#undef DefineVectorValueType
|
||||
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
return "Undefined";
|
||||
}
|
||||
|
||||
|
||||
} // namespace Convert
|
||||
} // namespace Helper
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_HELPER_STRINGCONVERTHELPER_H_
|
@ -1,47 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_INDEXBUILDER_OPTIONS_H_
|
||||
#define _SPTAG_INDEXBUILDER_OPTIONS_H_
|
||||
|
||||
#include "inc/Core/Common.h"
|
||||
#include "inc/Helper/ArgumentsParser.h"
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace IndexBuilder
|
||||
{
|
||||
|
||||
class BuilderOptions : public Helper::ArgumentsParser
|
||||
{
|
||||
public:
|
||||
BuilderOptions();
|
||||
|
||||
~BuilderOptions();
|
||||
|
||||
std::uint32_t m_threadNum;
|
||||
|
||||
std::uint32_t m_dimension;
|
||||
|
||||
std::string m_vectorDelimiter;
|
||||
|
||||
SPTAG::VectorValueType m_inputValueType;
|
||||
|
||||
std::string m_inputFiles;
|
||||
|
||||
std::string m_outputFolder;
|
||||
|
||||
SPTAG::IndexAlgoType m_indexAlgoType;
|
||||
|
||||
std::string m_builderConfigFile;
|
||||
};
|
||||
|
||||
|
||||
} // namespace IndexBuilder
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_INDEXBUILDER_OPTIONS_H_
|
@ -1,27 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_INDEXBUILDER_THREADPOOL_H_
|
||||
#define _SPTAG_INDEXBUILDER_THREADPOOL_H_
|
||||
|
||||
#include <functional>
|
||||
#include <cstdint>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace IndexBuilder
|
||||
{
|
||||
namespace ThreadPool
|
||||
{
|
||||
|
||||
void Init(std::uint32_t p_threadNum);
|
||||
|
||||
bool Queue(std::function<void()> p_workItem);
|
||||
|
||||
std::uint32_t CurrentThreadNum();
|
||||
|
||||
}
|
||||
} // namespace IndexBuilder
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_INDEXBUILDER_THREADPOOL_H_
|
@ -1,43 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADER_H_
|
||||
#define _SPTAG_INDEXBUILDER_VECTORSETREADER_H_
|
||||
|
||||
#include "inc/Core/Common.h"
|
||||
#include "inc/Core/VectorSet.h"
|
||||
#include "inc/Core/MetadataSet.h"
|
||||
#include "Options.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace IndexBuilder
|
||||
{
|
||||
|
||||
class VectorSetReader
|
||||
{
|
||||
public:
|
||||
VectorSetReader(std::shared_ptr<BuilderOptions> p_options);
|
||||
|
||||
virtual ~VectorSetReader();
|
||||
|
||||
virtual ErrorCode LoadFile(const std::string& p_filePath) = 0;
|
||||
|
||||
virtual std::shared_ptr<VectorSet> GetVectorSet() const = 0;
|
||||
|
||||
virtual std::shared_ptr<MetadataSet> GetMetadataSet() const = 0;
|
||||
|
||||
static std::shared_ptr<VectorSetReader> CreateInstance(std::shared_ptr<BuilderOptions> p_options);
|
||||
|
||||
protected:
|
||||
std::shared_ptr<BuilderOptions> m_options;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace IndexBuilder
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADER_H_
|
@ -1,108 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_
|
||||
#define _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULTREADER_H_
|
||||
|
||||
#include "../VectorSetReader.h"
|
||||
#include "inc/Helper/Concurrent.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <condition_variable>
|
||||
#include <mutex>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace IndexBuilder
|
||||
{
|
||||
|
||||
class DefaultReader : public VectorSetReader
|
||||
{
|
||||
public:
|
||||
DefaultReader(std::shared_ptr<BuilderOptions> p_options);
|
||||
|
||||
virtual ~DefaultReader();
|
||||
|
||||
virtual ErrorCode LoadFile(const std::string& p_filePaths);
|
||||
|
||||
virtual std::shared_ptr<VectorSet> GetVectorSet() const;
|
||||
|
||||
virtual std::shared_ptr<MetadataSet> GetMetadataSet() const;
|
||||
|
||||
private:
|
||||
typedef std::pair<std::string, std::size_t> FileInfoPair;
|
||||
|
||||
static std::vector<FileInfoPair> GetFileSizes(const std::string& p_filePaths);
|
||||
|
||||
void LoadFileInternal(const std::string& p_filePath,
|
||||
std::uint32_t p_subtaskID,
|
||||
std::uint32_t p_fileBlockID,
|
||||
std::size_t p_fileBlockSize);
|
||||
|
||||
void MergeData();
|
||||
|
||||
template<typename DataType>
|
||||
bool TranslateVector(char* p_str, DataType* p_vector)
|
||||
{
|
||||
std::uint32_t eleCount = 0;
|
||||
char* next = p_str;
|
||||
while ((*next) != '\0')
|
||||
{
|
||||
while ((*next) != '\0' && m_options->m_vectorDelimiter.find(*next) == std::string::npos)
|
||||
{
|
||||
++next;
|
||||
}
|
||||
|
||||
bool reachEnd = ('\0' == (*next));
|
||||
*next = '\0';
|
||||
if (p_str != next)
|
||||
{
|
||||
if (eleCount >= m_options->m_dimension)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!Helper::Convert::ConvertStringTo(p_str, p_vector[eleCount++]))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (reachEnd)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
++next;
|
||||
p_str = next;
|
||||
}
|
||||
|
||||
return eleCount == m_options->m_dimension;
|
||||
}
|
||||
|
||||
private:
|
||||
std::uint32_t m_subTaskCount;
|
||||
|
||||
std::size_t m_subTaskBlocksize;
|
||||
|
||||
std::atomic<std::uint32_t> m_totalRecordCount;
|
||||
|
||||
std::atomic<std::size_t> m_totalRecordVectorBytes;
|
||||
|
||||
std::vector<std::uint32_t> m_subTaskRecordCount;
|
||||
|
||||
std::string m_vectorOutput;
|
||||
|
||||
std::string m_metadataConentOutput;
|
||||
|
||||
std::string m_metadataIndexOutput;
|
||||
|
||||
Helper::Concurrent::WaitSignal m_waitSignal;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace IndexBuilder
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_INDEXBUILDER_VECTORSETREADERS_DEFAULT_H_
|
@ -1,56 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_QUERYPARSER_H_
|
||||
#define _SPTAG_SERVER_QUERYPARSER_H_
|
||||
|
||||
#include "../Core/Common.h"
|
||||
#include "../Core/CommonDataStructure.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
|
||||
class QueryParser
|
||||
{
|
||||
public:
|
||||
typedef std::pair<const char*, const char*> OptionPair;
|
||||
|
||||
QueryParser();
|
||||
|
||||
~QueryParser();
|
||||
|
||||
ErrorCode Parse(const std::string& p_query, const char* p_vectorSeparator);
|
||||
|
||||
const std::vector<const char*>& GetVectorElements() const;
|
||||
|
||||
const std::vector<OptionPair>& GetOptions() const;
|
||||
|
||||
const char* GetVectorBase64() const;
|
||||
|
||||
SizeType GetVectorBase64Length() const;
|
||||
|
||||
private:
|
||||
std::vector<OptionPair> m_options;
|
||||
|
||||
std::vector<const char*> m_vectorElements;
|
||||
|
||||
const char* m_vectorBase64;
|
||||
|
||||
SizeType m_vectorBase64Length;
|
||||
|
||||
ByteArray m_dataHolder;
|
||||
|
||||
static const char* c_defaultVectorSeparator;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_SERVER_QUERYPARSER_H_
|
@ -1,80 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_
|
||||
#define _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_
|
||||
|
||||
#include "inc/Core/VectorIndex.h"
|
||||
#include "inc/Core/SearchQuery.h"
|
||||
#include "inc/Socket/RemoteSearchQuery.h"
|
||||
#include "ServiceSettings.h"
|
||||
#include "QueryParser.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
typedef Socket::IndexSearchResult SearchResult;
|
||||
|
||||
class SearchExecutionContext
|
||||
{
|
||||
public:
|
||||
SearchExecutionContext(const std::shared_ptr<const ServiceSettings>& p_serviceSettings);
|
||||
|
||||
~SearchExecutionContext();
|
||||
|
||||
ErrorCode ParseQuery(const std::string& p_query);
|
||||
|
||||
ErrorCode ExtractOption();
|
||||
|
||||
ErrorCode ExtractVector(VectorValueType p_targetType);
|
||||
|
||||
void AddResults(std::string p_indexName, QueryResult& p_results);
|
||||
|
||||
std::vector<SearchResult>& GetResults();
|
||||
|
||||
const std::vector<SearchResult>& GetResults() const;
|
||||
|
||||
const ByteArray& GetVector() const;
|
||||
|
||||
const std::vector<std::string>& GetSelectedIndexNames() const;
|
||||
|
||||
const SizeType GetVectorDimension() const;
|
||||
|
||||
const std::vector<QueryParser::OptionPair>& GetOptions() const;
|
||||
|
||||
const SizeType GetResultNum() const;
|
||||
|
||||
const bool GetExtractMetadata() const;
|
||||
|
||||
private:
|
||||
const std::shared_ptr<const ServiceSettings> c_serviceSettings;
|
||||
|
||||
QueryParser m_queryParser;
|
||||
|
||||
std::vector<std::string> m_indexNames;
|
||||
|
||||
ByteArray m_vector;
|
||||
|
||||
SizeType m_vectorDimension;
|
||||
|
||||
std::vector<SearchResult> m_results;
|
||||
|
||||
VectorValueType m_inputValueType;
|
||||
|
||||
bool m_extractMetadata;
|
||||
|
||||
SizeType m_resultNum;
|
||||
};
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_SERVER_SEARCHEXECUTIONCONTEXT_H_
|
@ -1,56 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_SEARCHEXECUTOR_H_
|
||||
#define _SPTAG_SERVER_SEARCHEXECUTOR_H_
|
||||
|
||||
#include "ServiceContext.h"
|
||||
#include "ServiceSettings.h"
|
||||
#include "SearchExecutionContext.h"
|
||||
#include "QueryParser.h"
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
class SearchExecutor
|
||||
{
|
||||
public:
|
||||
typedef std::function<void(std::shared_ptr<SearchExecutionContext>)> CallBack;
|
||||
|
||||
SearchExecutor(std::string p_queryString,
|
||||
std::shared_ptr<ServiceContext> p_serviceContext,
|
||||
const CallBack& p_callback);
|
||||
|
||||
~SearchExecutor();
|
||||
|
||||
void Execute();
|
||||
|
||||
private:
|
||||
void ExecuteInternal();
|
||||
|
||||
void SelectIndex();
|
||||
|
||||
private:
|
||||
CallBack m_callback;
|
||||
|
||||
const std::shared_ptr<ServiceContext> c_serviceContext;
|
||||
|
||||
std::shared_ptr<SearchExecutionContext> m_executionContext;
|
||||
|
||||
std::string m_queryString;
|
||||
|
||||
std::vector<std::shared_ptr<VectorIndex>> m_selectedIndex;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_SERVER_SEARCHEXECUTOR_H_
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_SERVICE_H_
|
||||
#define _SPTAG_SERVER_SERVICE_H_
|
||||
|
||||
#include "ServiceContext.h"
|
||||
#include "../Socket/Server.h"
|
||||
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <thread>
|
||||
#include <condition_variable>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
class SearchExecutionContext;
|
||||
|
||||
class SearchService
|
||||
{
|
||||
public:
|
||||
SearchService();
|
||||
|
||||
~SearchService();
|
||||
|
||||
bool Initialize(int p_argNum, char* p_args[]);
|
||||
|
||||
void Run();
|
||||
|
||||
private:
|
||||
void RunSocketMode();
|
||||
|
||||
void RunInteractiveMode();
|
||||
|
||||
void SearchHanlder(Socket::ConnectionID p_localConnectionID, Socket::Packet p_packet);
|
||||
|
||||
void SearchHanlderCallback(std::shared_ptr<SearchExecutionContext> p_exeContext,
|
||||
Socket::Packet p_srcPacket);
|
||||
|
||||
private:
|
||||
enum class ServeMode : std::uint8_t
|
||||
{
|
||||
Interactive,
|
||||
|
||||
Socket
|
||||
};
|
||||
|
||||
std::shared_ptr<ServiceContext> m_serviceContext;
|
||||
|
||||
std::shared_ptr<Socket::Server> m_socketServer;
|
||||
|
||||
bool m_initialized;
|
||||
|
||||
ServeMode m_serveMode;
|
||||
|
||||
std::unique_ptr<boost::asio::thread_pool> m_threadPool;
|
||||
|
||||
boost::asio::io_context m_ioContext;
|
||||
|
||||
boost::asio::signal_set m_shutdownSignals;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_SERVER_SERVICE_H_
|
@ -1,44 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_SERVICECONTEX_H_
|
||||
#define _SPTAG_SERVER_SERVICECONTEX_H_
|
||||
|
||||
#include "inc/Core/VectorIndex.h"
|
||||
#include "ServiceSettings.h"
|
||||
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
class ServiceContext
|
||||
{
|
||||
public:
|
||||
ServiceContext(const std::string& p_configFilePath);
|
||||
|
||||
~ServiceContext();
|
||||
|
||||
const std::map<std::string, std::shared_ptr<VectorIndex>>& GetIndexMap() const;
|
||||
|
||||
const std::shared_ptr<ServiceSettings>& GetServiceSettings() const;
|
||||
|
||||
bool IsInitialized() const;
|
||||
|
||||
private:
|
||||
bool m_initialized;
|
||||
|
||||
std::shared_ptr<ServiceSettings> m_settings;
|
||||
|
||||
std::map<std::string, std::shared_ptr<VectorIndex>> m_fullIndexList;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
#endif // _SPTAG_SERVER_SERVICECONTEX_H_
|
||||
|
@ -1,41 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SERVER_SERVICESTTINGS_H_
|
||||
#define _SPTAG_SERVER_SERVICESTTINGS_H_
|
||||
|
||||
#include "../Core/Common.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Service
|
||||
{
|
||||
|
||||
struct ServiceSettings
|
||||
{
|
||||
ServiceSettings();
|
||||
|
||||
std::string m_vectorSeparator;
|
||||
|
||||
std::string m_listenAddr;
|
||||
|
||||
std::string m_listenPort;
|
||||
|
||||
SizeType m_defaultMaxResultNumber;
|
||||
|
||||
SizeType m_threadNum;
|
||||
|
||||
SizeType m_socketThreadNum;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
} // namespace Server
|
||||
} // namespace AnnService
|
||||
|
||||
|
||||
#endif // _SPTAG_SERVER_SERVICESTTINGS_H_
|
||||
|
@ -1,68 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_CLIENT_H_
|
||||
#define _SPTAG_SOCKET_CLIENT_H_
|
||||
|
||||
#include "inc/Core/Common.h"
|
||||
#include "Connection.h"
|
||||
#include "ConnectionManager.h"
|
||||
#include "Packet.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
class Client
|
||||
{
|
||||
public:
|
||||
typedef std::function<void(ConnectionID p_cid, SPTAG::ErrorCode)> ConnectCallback;
|
||||
|
||||
Client(const PacketHandlerMapPtr& p_handlerMap,
|
||||
std::size_t p_threadNum,
|
||||
std::uint32_t p_heartbeatIntervalSeconds);
|
||||
|
||||
~Client();
|
||||
|
||||
ConnectionID ConnectToServer(const std::string& p_address,
|
||||
const std::string& p_port,
|
||||
SPTAG::ErrorCode& p_ec);
|
||||
|
||||
void AsyncConnectToServer(const std::string& p_address,
|
||||
const std::string& p_port,
|
||||
ConnectCallback p_callback);
|
||||
|
||||
void SendPacket(ConnectionID p_connection, Packet p_packet, std::function<void(bool)> p_callback);
|
||||
|
||||
void SetEventOnConnectionClose(std::function<void(ConnectionID)> p_event);
|
||||
|
||||
private:
|
||||
void KeepIoContext();
|
||||
|
||||
private:
|
||||
std::atomic_bool m_stopped;
|
||||
|
||||
std::uint32_t m_heartbeatIntervalSeconds;
|
||||
|
||||
boost::asio::io_context m_ioContext;
|
||||
|
||||
boost::asio::deadline_timer m_deadlineTimer;
|
||||
|
||||
std::shared_ptr<ConnectionManager> m_connectionManager;
|
||||
|
||||
std::vector<std::thread> m_threadPool;
|
||||
|
||||
const PacketHandlerMapPtr c_requestHandlerMap;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_CLIENT_H_
|
@ -1,25 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_COMMON_H_
|
||||
#define _SPTAG_SOCKET_COMMON_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
typedef std::uint32_t ConnectionID;
|
||||
|
||||
typedef std::uint32_t ResourceID;
|
||||
|
||||
extern const ConnectionID c_invalidConnectionID;
|
||||
|
||||
extern const ResourceID c_invalidResourceID;
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_COMMON_H_
|
@ -1,100 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_CONNECTION_H_
|
||||
#define _SPTAG_SOCKET_CONNECTION_H_
|
||||
|
||||
#include "Packet.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
#include <boost/asio/io_service_strand.hpp>
|
||||
#include <boost/asio/deadline_timer.hpp>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
class ConnectionManager;
|
||||
|
||||
class Connection : public std::enable_shared_from_this<Connection>
|
||||
{
|
||||
public:
|
||||
typedef std::shared_ptr<Connection> Ptr;
|
||||
|
||||
Connection(ConnectionID p_connectionID,
|
||||
boost::asio::ip::tcp::socket&& p_socket,
|
||||
const PacketHandlerMapPtr& p_handlerMap,
|
||||
std::weak_ptr<ConnectionManager> p_connectionManager);
|
||||
|
||||
void Start();
|
||||
|
||||
void Stop();
|
||||
|
||||
void StartHeartbeat(std::size_t p_intervalSeconds);
|
||||
|
||||
void AsyncSend(Packet p_packet, std::function<void(bool)> p_callback);
|
||||
|
||||
ConnectionID GetConnectionID() const;
|
||||
|
||||
ConnectionID GetRemoteConnectionID() const;
|
||||
|
||||
Connection(const Connection&) = delete;
|
||||
Connection& operator=(const Connection&) = delete;
|
||||
|
||||
private:
|
||||
void AsyncReadHeader();
|
||||
|
||||
void AsyncReadBody();
|
||||
|
||||
void HandleReadHeader(boost::system::error_code p_ec, std::size_t p_bytesTransferred);
|
||||
|
||||
void HandleReadBody(boost::system::error_code p_ec, std::size_t p_bytesTransferred);
|
||||
|
||||
void SendHeartbeat(std::size_t p_intervalSeconds);
|
||||
|
||||
void SendRegister();
|
||||
|
||||
void HandleHeartbeatRequest();
|
||||
|
||||
void HandleRegisterRequest();
|
||||
|
||||
void HandleRegisterResponse();
|
||||
|
||||
void HandleNoHandlerResponse();
|
||||
|
||||
void OnConnectionFail(const boost::system::error_code& p_ec);
|
||||
|
||||
private:
|
||||
const ConnectionID c_connectionID;
|
||||
|
||||
ConnectionID m_remoteConnectionID;
|
||||
|
||||
const std::weak_ptr<ConnectionManager> c_connectionManager;
|
||||
|
||||
const PacketHandlerMapPtr c_handlerMap;
|
||||
|
||||
boost::asio::ip::tcp::socket m_socket;
|
||||
|
||||
boost::asio::io_context::strand m_strand;
|
||||
|
||||
boost::asio::deadline_timer m_heartbeatTimer;
|
||||
|
||||
std::uint8_t m_packetHeaderReadBuffer[PacketHeader::c_bufferSize];
|
||||
|
||||
Packet m_packetRead;
|
||||
|
||||
std::atomic_bool m_stopped;
|
||||
|
||||
std::atomic_bool m_heartbeatStarted;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_CONNECTION_H_
|
@ -1,73 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_CONNECTIONMANAGER_H_
|
||||
#define _SPTAG_SOCKET_CONNECTIONMANAGER_H_
|
||||
|
||||
#include "Connection.h"
|
||||
#include "inc/Helper/Concurrent.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <array>
|
||||
|
||||
#include <boost/asio/ip/tcp.hpp>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
class ConnectionManager : public std::enable_shared_from_this<ConnectionManager>
|
||||
{
|
||||
public:
|
||||
ConnectionManager();
|
||||
|
||||
ConnectionID AddConnection(boost::asio::ip::tcp::socket&& p_socket,
|
||||
const PacketHandlerMapPtr& p_handlerMap,
|
||||
std::uint32_t p_heartbeatIntervalSeconds);
|
||||
|
||||
void RemoveConnection(ConnectionID p_connectionID);
|
||||
|
||||
Connection::Ptr GetConnection(ConnectionID p_connectionID);
|
||||
|
||||
void SetEventOnRemoving(std::function<void(ConnectionID)> p_event);
|
||||
|
||||
void StopAll();
|
||||
|
||||
private:
|
||||
inline static std::uint32_t GetPosition(ConnectionID p_connectionID);
|
||||
|
||||
private:
|
||||
static constexpr std::uint32_t c_connectionPoolSize = 1 << 8;
|
||||
|
||||
static constexpr std::uint32_t c_connectionPoolMask = c_connectionPoolSize - 1;
|
||||
|
||||
struct ConnectionItem
|
||||
{
|
||||
ConnectionItem();
|
||||
|
||||
std::atomic_bool m_isEmpty;
|
||||
|
||||
Connection::Ptr m_connection;
|
||||
};
|
||||
|
||||
// Start from 1. 0 means not assigned.
|
||||
std::atomic<ConnectionID> m_nextConnectionID;
|
||||
|
||||
std::atomic<std::uint32_t> m_connectionCount;
|
||||
|
||||
std::array<ConnectionItem, c_connectionPoolSize> m_connections;
|
||||
|
||||
Helper::Concurrent::SpinLock m_spinLock;
|
||||
|
||||
std::function<void(ConnectionID)> m_eventOnRemoving;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_CONNECTIONMANAGER_H_
|
@ -1,142 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_PACKET_H_
|
||||
#define _SPTAG_SOCKET_PACKET_H_
|
||||
|
||||
#include "Common.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <array>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
enum class PacketType : std::uint8_t
|
||||
{
|
||||
Undefined = 0x00,
|
||||
|
||||
HeartbeatRequest = 0x01,
|
||||
|
||||
RegisterRequest = 0x02,
|
||||
|
||||
SearchRequest = 0x03,
|
||||
|
||||
ResponseMask = 0x80,
|
||||
|
||||
HeartbeatResponse = ResponseMask | HeartbeatRequest,
|
||||
|
||||
RegisterResponse = ResponseMask | RegisterRequest,
|
||||
|
||||
SearchResponse = ResponseMask | SearchRequest
|
||||
};
|
||||
|
||||
|
||||
enum class PacketProcessStatus : std::uint8_t
|
||||
{
|
||||
Ok = 0x00,
|
||||
|
||||
Timeout = 0x01,
|
||||
|
||||
Dropped = 0x02,
|
||||
|
||||
Failed = 0x03
|
||||
};
|
||||
|
||||
|
||||
struct PacketHeader
|
||||
{
|
||||
static constexpr std::size_t c_bufferSize = 16;
|
||||
|
||||
PacketHeader();
|
||||
PacketHeader(PacketHeader&& p_right);
|
||||
PacketHeader(const PacketHeader& p_right);
|
||||
|
||||
std::size_t WriteBuffer(std::uint8_t* p_buffer);
|
||||
|
||||
void ReadBuffer(const std::uint8_t* p_buffer);
|
||||
|
||||
PacketType m_packetType;
|
||||
|
||||
PacketProcessStatus m_processStatus;
|
||||
|
||||
std::uint32_t m_bodyLength;
|
||||
|
||||
// Meaning of this is different with different PacketType.
|
||||
// In most request case, it means connection expeced for response.
|
||||
// In most response case, it means connection which handled request.
|
||||
ConnectionID m_connectionID;
|
||||
|
||||
ResourceID m_resourceID;
|
||||
};
|
||||
|
||||
|
||||
static_assert(sizeof(PacketHeader) <= PacketHeader::c_bufferSize, "");
|
||||
|
||||
|
||||
class Packet
|
||||
{
|
||||
public:
|
||||
Packet();
|
||||
Packet(Packet&& p_right);
|
||||
Packet(const Packet& p_right);
|
||||
|
||||
PacketHeader& Header();
|
||||
|
||||
std::uint8_t* HeaderBuffer() const;
|
||||
|
||||
std::uint8_t* Body() const;
|
||||
|
||||
std::uint8_t* Buffer() const;
|
||||
|
||||
std::uint32_t BufferLength() const;
|
||||
|
||||
std::uint32_t BufferCapacity() const;
|
||||
|
||||
void AllocateBuffer(std::uint32_t p_bodyCapacity);
|
||||
|
||||
private:
|
||||
PacketHeader m_header;
|
||||
|
||||
std::shared_ptr<std::uint8_t> m_buffer;
|
||||
|
||||
std::uint32_t m_bufferCapacity;
|
||||
};
|
||||
|
||||
|
||||
struct PacketTypeHash
|
||||
{
|
||||
std::size_t operator()(const PacketType& p_val) const
|
||||
{
|
||||
return static_cast<std::size_t>(p_val);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
typedef std::function<void(ConnectionID, Packet)> PacketHandler;
|
||||
|
||||
typedef std::unordered_map<PacketType, PacketHandler, PacketTypeHash> PacketHandlerMap;
|
||||
typedef std::shared_ptr<PacketHandlerMap> PacketHandlerMapPtr;
|
||||
|
||||
|
||||
namespace PacketTypeHelper
|
||||
{
|
||||
|
||||
bool IsRequestPacket(PacketType p_type);
|
||||
|
||||
bool IsResponsePacket(PacketType p_type);
|
||||
|
||||
PacketType GetCrosspondingResponseType(PacketType p_type);
|
||||
|
||||
}
|
||||
|
||||
|
||||
} // namespace SPTAG
|
||||
} // namespace Socket
|
||||
|
||||
#endif // _SPTAG_SOCKET_SOCKETSERVER_H_
|
@ -1,99 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_REMOTESEARCHQUERY_H_
|
||||
#define _SPTAG_SOCKET_REMOTESEARCHQUERY_H_
|
||||
|
||||
#include "inc/Core/CommonDataStructure.h"
|
||||
#include "inc/Core/SearchQuery.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
// TODO: use Bond replace below structures.
|
||||
|
||||
struct RemoteQuery
|
||||
{
|
||||
static constexpr std::uint16_t MajorVersion() { return 1; }
|
||||
static constexpr std::uint16_t MirrorVersion() { return 0; }
|
||||
|
||||
enum class QueryType : std::uint8_t
|
||||
{
|
||||
String = 0
|
||||
};
|
||||
|
||||
RemoteQuery();
|
||||
|
||||
std::size_t EstimateBufferSize() const;
|
||||
|
||||
std::uint8_t* Write(std::uint8_t* p_buffer) const;
|
||||
|
||||
const std::uint8_t* Read(const std::uint8_t* p_buffer);
|
||||
|
||||
|
||||
QueryType m_type;
|
||||
|
||||
std::string m_queryString;
|
||||
};
|
||||
|
||||
|
||||
struct IndexSearchResult
|
||||
{
|
||||
std::string m_indexName;
|
||||
|
||||
QueryResult m_results;
|
||||
};
|
||||
|
||||
|
||||
struct RemoteSearchResult
|
||||
{
|
||||
static constexpr std::uint16_t MajorVersion() { return 1; }
|
||||
static constexpr std::uint16_t MirrorVersion() { return 0; }
|
||||
|
||||
enum class ResultStatus : std::uint8_t
|
||||
{
|
||||
Success = 0,
|
||||
|
||||
Timeout = 1,
|
||||
|
||||
FailedNetwork = 2,
|
||||
|
||||
FailedExecute = 3,
|
||||
|
||||
Dropped = 4
|
||||
};
|
||||
|
||||
RemoteSearchResult();
|
||||
|
||||
RemoteSearchResult(const RemoteSearchResult& p_right);
|
||||
|
||||
RemoteSearchResult(RemoteSearchResult&& p_right);
|
||||
|
||||
RemoteSearchResult& operator=(RemoteSearchResult&& p_right);
|
||||
|
||||
std::size_t EstimateBufferSize() const;
|
||||
|
||||
std::uint8_t* Write(std::uint8_t* p_buffer) const;
|
||||
|
||||
const std::uint8_t* Read(const std::uint8_t* p_buffer);
|
||||
|
||||
|
||||
ResultStatus m_status;
|
||||
|
||||
std::vector<IndexSearchResult> m_allIndexResults;
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // namespace SPTAG
|
||||
} // namespace Socket
|
||||
|
||||
#endif // _SPTAG_SOCKET_REMOTESEARCHQUERY_H_
|
@ -1,190 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_RESOURCEMANAGER_H_
|
||||
#define _SPTAG_SOCKET_RESOURCEMANAGER_H_
|
||||
|
||||
#include "Common.h"
|
||||
|
||||
#include <boost/asio/io_context.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <deque>
|
||||
#include <unordered_map>
|
||||
#include <thread>
|
||||
|
||||
namespace std
|
||||
{
|
||||
typedef atomic<uint32_t> atomic_uint32_t;
|
||||
}
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
template<typename ResourceType>
|
||||
class ResourceManager : public std::enable_shared_from_this<ResourceManager<ResourceType>>
|
||||
{
|
||||
public:
|
||||
typedef std::function<void(std::shared_ptr<ResourceType>)> TimeoutCallback;
|
||||
|
||||
ResourceManager()
|
||||
: m_nextResourceID(1),
|
||||
m_isStopped(false),
|
||||
m_timeoutItemCount(0)
|
||||
{
|
||||
m_timeoutChecker = std::thread(&ResourceManager::StartCheckTimeout, this);
|
||||
}
|
||||
|
||||
|
||||
~ResourceManager()
|
||||
{
|
||||
m_isStopped = true;
|
||||
m_timeoutChecker.join();
|
||||
}
|
||||
|
||||
|
||||
ResourceID Add(const std::shared_ptr<ResourceType>& p_resource,
|
||||
std::uint32_t p_timeoutMilliseconds,
|
||||
TimeoutCallback p_timeoutCallback)
|
||||
{
|
||||
ResourceID rid = m_nextResourceID.fetch_add(1);
|
||||
while (c_invalidResourceID == rid)
|
||||
{
|
||||
rid = m_nextResourceID.fetch_add(1);
|
||||
}
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(m_resourcesMutex);
|
||||
m_resources.emplace(rid, p_resource);
|
||||
}
|
||||
|
||||
if (p_timeoutMilliseconds > 0)
|
||||
{
|
||||
std::unique_ptr<ResourceItem> item(new ResourceItem);
|
||||
|
||||
item->m_resourceID = rid;
|
||||
item->m_callback = std::move(p_timeoutCallback);
|
||||
item->m_expireTime = m_clock.now() + std::chrono::milliseconds(p_timeoutMilliseconds);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(m_timeoutListMutex);
|
||||
m_timeoutList.emplace_back(std::move(item));
|
||||
}
|
||||
|
||||
++m_timeoutItemCount;
|
||||
}
|
||||
|
||||
return rid;
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<ResourceType> GetAndRemove(ResourceID p_resourceID)
|
||||
{
|
||||
std::shared_ptr<ResourceType> ret;
|
||||
std::lock_guard<std::mutex> guard(m_resourcesMutex);
|
||||
auto iter = m_resources.find(p_resourceID);
|
||||
if (iter != m_resources.end())
|
||||
{
|
||||
ret = iter->second;
|
||||
m_resources.erase(iter);
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
void Remove(ResourceID p_resourceID)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(m_resourcesMutex);
|
||||
auto iter = m_resources.find(p_resourceID);
|
||||
if (iter != m_resources.end())
|
||||
{
|
||||
m_resources.erase(iter);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void StartCheckTimeout()
|
||||
{
|
||||
std::vector<std::unique_ptr<ResourceItem>> timeouted;
|
||||
timeouted.reserve(1024);
|
||||
while (!m_isStopped)
|
||||
{
|
||||
if (m_timeoutItemCount > 0)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(m_timeoutListMutex);
|
||||
while (!m_timeoutList.empty()
|
||||
&& m_timeoutList.front()->m_expireTime <= m_clock.now())
|
||||
{
|
||||
timeouted.emplace_back(std::move(m_timeoutList.front()));
|
||||
m_timeoutList.pop_front();
|
||||
--m_timeoutItemCount;
|
||||
}
|
||||
}
|
||||
|
||||
if (timeouted.empty())
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1));
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto& item : timeouted)
|
||||
{
|
||||
auto resource = GetAndRemove(item->m_resourceID);
|
||||
if (nullptr != resource)
|
||||
{
|
||||
item->m_callback(std::move(resource));
|
||||
}
|
||||
}
|
||||
|
||||
timeouted.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
struct ResourceItem
|
||||
{
|
||||
ResourceItem()
|
||||
: m_resourceID(c_invalidResourceID)
|
||||
{
|
||||
}
|
||||
|
||||
ResourceID m_resourceID;
|
||||
|
||||
TimeoutCallback m_callback;
|
||||
|
||||
std::chrono::time_point<std::chrono::high_resolution_clock> m_expireTime;
|
||||
};
|
||||
|
||||
std::deque<std::unique_ptr<ResourceItem>> m_timeoutList;
|
||||
|
||||
std::atomic<std::uint32_t> m_timeoutItemCount;
|
||||
|
||||
std::mutex m_timeoutListMutex;
|
||||
|
||||
std::unordered_map<ResourceID, std::shared_ptr<ResourceType>> m_resources;
|
||||
|
||||
std::atomic<ResourceID> m_nextResourceID;
|
||||
|
||||
std::mutex m_resourcesMutex;
|
||||
|
||||
std::chrono::high_resolution_clock m_clock;
|
||||
|
||||
std::thread m_timeoutChecker;
|
||||
|
||||
bool m_isStopped;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_RESOURCEMANAGER_H_
|
@ -1,55 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_SERVER_H_
|
||||
#define _SPTAG_SOCKET_SERVER_H_
|
||||
|
||||
#include "Connection.h"
|
||||
#include "ConnectionManager.h"
|
||||
#include "Packet.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <boost/asio.hpp>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
|
||||
class Server
|
||||
{
|
||||
public:
|
||||
Server(const std::string& p_address,
|
||||
const std::string& p_port,
|
||||
const PacketHandlerMapPtr& p_handlerMap,
|
||||
std::size_t p_threadNum);
|
||||
|
||||
~Server();
|
||||
|
||||
void StartListen();
|
||||
|
||||
void SendPacket(ConnectionID p_connection, Packet p_packet, std::function<void(bool)> p_callback);
|
||||
|
||||
void SetEventOnConnectionClose(std::function<void(ConnectionID)> p_event);
|
||||
|
||||
private:
|
||||
void StartAccept();
|
||||
|
||||
private:
|
||||
boost::asio::io_context m_ioContext;
|
||||
|
||||
boost::asio::ip::tcp::acceptor m_acceptor;
|
||||
|
||||
std::shared_ptr<ConnectionManager> m_connectionManager;
|
||||
|
||||
std::vector<std::thread> m_threadPool;
|
||||
|
||||
const PacketHandlerMapPtr m_requestHandlerMap;
|
||||
};
|
||||
|
||||
|
||||
} // namespace Socket
|
||||
} // namespace SPTAG
|
||||
|
||||
#endif // _SPTAG_SOCKET_SERVER_H_
|
@ -1,174 +0,0 @@
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#ifndef _SPTAG_SOCKET_SIMPLESERIALIZATION_H_
|
||||
#define _SPTAG_SOCKET_SIMPLESERIALIZATION_H_
|
||||
|
||||
#include "inc/Core/CommonDataStructure.h"
|
||||
|
||||
#include <type_traits>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <cstring>
|
||||
|
||||
namespace SPTAG
|
||||
{
|
||||
namespace Socket
|
||||
{
|
||||
namespace SimpleSerialization
|
||||
{
|
||||
|
||||
template<typename T>
|
||||
inline std::uint8_t*
|
||||
SimpleWriteBuffer(const T& p_val, std::uint8_t* p_buffer)
|
||||
{
|
||||
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
|
||||
"Only applied for fundanmental type.");
|
||||
|
||||
*(reinterpret_cast<T*>(p_buffer)) = p_val;
|
||||
return p_buffer + sizeof(T);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
inline const std::uint8_t*
|
||||
SimpleReadBuffer(const std::uint8_t* p_buffer, T& p_val)
|
||||
{
|
||||
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
|
||||
"Only applied for fundanmental type.");
|
||||
|
||||
p_val = *(reinterpret_cast<const T*>(p_buffer));
|
||||
return p_buffer + sizeof(T);
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
inline std::size_t
|
||||
EstimateBufferSize(const T& p_val)
|
||||
{
|
||||
static_assert(std::is_fundamental<T>::value || std::is_enum<T>::value,
|
||||
"Only applied for fundanmental type.");
|
||||
|
||||
return sizeof(T);
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline std::uint8_t*
|
||||
SimpleWriteBuffer<std::string>(const std::string& p_val, std::uint8_t* p_buffer)
|
||||
{
|
||||
p_buffer = SimpleWriteBuffer(static_cast<std::uint32_t>(p_val.size()), p_buffer);
|
||||
|
||||
std::memcpy(p_buffer, p_val.c_str(), p_val.size());
|
||||
return p_buffer + p_val.size();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline const std::uint8_t*
|
||||
SimpleReadBuffer<std::string>(const std::uint8_t* p_buffer, std::string& p_val)
|
||||
{
|
||||
p_val.clear();
|
||||
std::uint32_t len = 0;
|
||||
p_buffer = SimpleReadBuffer(p_buffer, len);
|
||||
|
||||
if (len > 0)
|
||||
{
|
||||
p_val.reserve(len);
|
||||
p_val.assign(reinterpret_cast<const char*>(p_buffer), len);
|
||||
}
|
||||
|
||||
return p_buffer + len;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline std::size_t
|
||||
EstimateBufferSize<std::string>(const std::string& p_val)
|
||||
{
|
||||
return sizeof(std::uint32_t) + p_val.size();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline std::uint8_t*
|
||||
SimpleWriteBuffer<ByteArray>(const ByteArray& p_val, std::uint8_t* p_buffer)
|
||||
{
|
||||
p_buffer = SimpleWriteBuffer(static_cast<std::uint32_t>(p_val.Length()), p_buffer);
|
||||
|
||||
std::memcpy(p_buffer, p_val.Data(), p_val.Length());
|
||||
return p_buffer + p_val.Length();
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline const std::uint8_t*
|
||||
SimpleReadBuffer<ByteArray>(const std::uint8_t* p_buffer, ByteArray& p_val)
|
||||
{
|
||||
p_val.Clear();
|
||||
std::uint32_t len = 0;
|
||||
p_buffer = SimpleReadBuffer(p_buffer, len);
|
||||
|
||||
if (len > 0)
|
||||
{
|
||||
p_val = ByteArray::Alloc(len);
|
||||
std::memcpy(p_val.Data(), p_buffer, len);
|
||||
}
|
||||
|
||||
return p_buffer + len;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
inline std::size_t
|
||||
EstimateBufferSize<ByteArray>(const ByteArray& p_val)
|
||||
{
|
||||
return sizeof(std::uint32_t) + p_val.Length();
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
inline std::uint8_t*
|
||||
SimpleWriteSharedPtrBuffer(const std::shared_ptr<T>& p_val, std::uint8_t* p_buffer)
|
||||
{
|
||||
if (nullptr == p_val)
|
||||
{
|
||||
return SimpleWriteBuffer(false, p_buffer);
|
||||
}
|
||||
|
||||
p_buffer = SimpleWriteBuffer(true, p_buffer);
|
||||
p_buffer = SimpleWriteBuffer(*p_val, p_buffer);
|
||||
return p_buffer;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
inline const std::uint8_t*
|
||||
SimpleReadSharedPtrBuffer(const std::uint8_t* p_buffer, std::shared_ptr<T>& p_val)
|
||||
{
|
||||
p_val.reset();
|
||||
bool isNotNull = false;
|
||||
p_buffer = SimpleReadBuffer(p_buffer, isNotNull);
|
||||
|
||||
if (isNotNull)
|
||||
{
|
||||
p_val.reset(new T);
|
||||
p_buffer = SimpleReadBuffer(p_buffer, *p_val);
|
||||
}
|
||||
|
||||
return p_buffer;
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
inline std::size_t
|
||||
EstimateSharedPtrBufferSize(const std::shared_ptr<T>& p_val)
|
||||
{
|
||||
return sizeof(bool) + (nullptr == p_val ? 0 : EstimateBufferSize(*p_val));
|
||||
}
|
||||
|
||||
} // namespace SimpleSerialization
|
||||
} // namespace SPTAG
|
||||
} // namespace Socket
|
||||
|
||||
#endif // _SPTAG_SOCKET_SIMPLESERIALIZATION_H_
|
@ -1,151 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_ALLOCATOR_H
|
||||
#define ARROW_ALLOCATOR_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "arrow/memory_pool.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
/// \brief A STL allocator delegating allocations to a Arrow MemoryPool
|
||||
template <class T>
|
||||
class stl_allocator {
|
||||
public:
|
||||
using value_type = T;
|
||||
using pointer = T*;
|
||||
using const_pointer = const T*;
|
||||
using reference = T&;
|
||||
using const_reference = const T&;
|
||||
using size_type = std::size_t;
|
||||
using difference_type = std::ptrdiff_t;
|
||||
|
||||
template <class U>
|
||||
struct rebind {
|
||||
using other = stl_allocator<U>;
|
||||
};
|
||||
|
||||
/// \brief Construct an allocator from the default MemoryPool
|
||||
stl_allocator() noexcept : pool_(default_memory_pool()) {}
|
||||
/// \brief Construct an allocator from the given MemoryPool
|
||||
explicit stl_allocator(MemoryPool* pool) noexcept : pool_(pool) {}
|
||||
|
||||
template <class U>
|
||||
stl_allocator(const stl_allocator<U>& rhs) noexcept : pool_(rhs.pool_) {}
|
||||
|
||||
~stl_allocator() { pool_ = NULLPTR; }
|
||||
|
||||
pointer address(reference r) const noexcept { return std::addressof(r); }
|
||||
|
||||
const_pointer address(const_reference r) const noexcept { return std::addressof(r); }
|
||||
|
||||
pointer allocate(size_type n, const void* /*hint*/ = NULLPTR) {
|
||||
uint8_t* data;
|
||||
Status s = pool_->Allocate(n * sizeof(T), &data);
|
||||
if (!s.ok()) throw std::bad_alloc();
|
||||
return reinterpret_cast<pointer>(data);
|
||||
}
|
||||
|
||||
void deallocate(pointer p, size_type n) {
|
||||
pool_->Free(reinterpret_cast<uint8_t*>(p), n * sizeof(T));
|
||||
}
|
||||
|
||||
size_type size_max() const noexcept { return size_type(-1) / sizeof(T); }
|
||||
|
||||
template <class U, class... Args>
|
||||
void construct(U* p, Args&&... args) {
|
||||
new (reinterpret_cast<void*>(p)) U(std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void destroy(U* p) {
|
||||
p->~U();
|
||||
}
|
||||
|
||||
MemoryPool* pool() const noexcept { return pool_; }
|
||||
|
||||
private:
|
||||
MemoryPool* pool_;
|
||||
};
|
||||
|
||||
/// \brief A MemoryPool implementation delegating allocations to a STL allocator
|
||||
///
|
||||
/// Note that STL allocators don't provide a resizing operation, and therefore
|
||||
/// any buffer resizes will do a full reallocation and copy.
|
||||
template <typename Allocator = std::allocator<uint8_t>>
|
||||
class STLMemoryPool : public MemoryPool {
|
||||
public:
|
||||
/// \brief Construct a memory pool from the given allocator
|
||||
explicit STLMemoryPool(const Allocator& alloc) : alloc_(alloc) {}
|
||||
|
||||
Status Allocate(int64_t size, uint8_t** out) override {
|
||||
try {
|
||||
*out = alloc_.allocate(size);
|
||||
} catch (std::bad_alloc& e) {
|
||||
return Status::OutOfMemory(e.what());
|
||||
}
|
||||
stats_.UpdateAllocatedBytes(size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reallocate(int64_t old_size, int64_t new_size, uint8_t** ptr) override {
|
||||
uint8_t* old_ptr = *ptr;
|
||||
try {
|
||||
*ptr = alloc_.allocate(new_size);
|
||||
} catch (std::bad_alloc& e) {
|
||||
return Status::OutOfMemory(e.what());
|
||||
}
|
||||
memcpy(*ptr, old_ptr, std::min(old_size, new_size));
|
||||
alloc_.deallocate(old_ptr, old_size);
|
||||
stats_.UpdateAllocatedBytes(new_size - old_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Free(uint8_t* buffer, int64_t size) override {
|
||||
alloc_.deallocate(buffer, size);
|
||||
stats_.UpdateAllocatedBytes(-size);
|
||||
}
|
||||
|
||||
int64_t bytes_allocated() const override { return stats_.bytes_allocated(); }
|
||||
|
||||
int64_t max_memory() const override { return stats_.max_memory(); }
|
||||
|
||||
private:
|
||||
Allocator alloc_;
|
||||
internal::MemoryPoolStats stats_;
|
||||
};
|
||||
|
||||
template <class T1, class T2>
|
||||
bool operator==(const stl_allocator<T1>& lhs, const stl_allocator<T2>& rhs) noexcept {
|
||||
return lhs.pool() == rhs.pool();
|
||||
}
|
||||
|
||||
template <class T1, class T2>
|
||||
bool operator!=(const stl_allocator<T1>& lhs, const stl_allocator<T2>& rhs) noexcept {
|
||||
return !(lhs == rhs);
|
||||
}
|
||||
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_ALLOCATOR_H
|
@ -1,43 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Coarse public API while the library is in development
|
||||
|
||||
#ifndef ARROW_API_H
|
||||
#define ARROW_API_H
|
||||
|
||||
#include "arrow/array.h" // IYWU pragma: export
|
||||
#include "arrow/buffer.h" // IYWU pragma: export
|
||||
#include "arrow/builder.h" // IYWU pragma: export
|
||||
#include "arrow/compare.h" // IYWU pragma: export
|
||||
#include "arrow/extension_type.h" // IYWU pragma: export
|
||||
#include "arrow/memory_pool.h" // IYWU pragma: export
|
||||
#include "arrow/pretty_print.h" // IYWU pragma: export
|
||||
#include "arrow/record_batch.h" // IYWU pragma: export
|
||||
#include "arrow/status.h" // IYWU pragma: export
|
||||
#include "arrow/table.h" // IYWU pragma: export
|
||||
#include "arrow/table_builder.h" // IYWU pragma: export
|
||||
#include "arrow/tensor.h" // IYWU pragma: export
|
||||
#include "arrow/type.h" // IYWU pragma: export
|
||||
#include "arrow/util/config.h" // IYWU pragma: export
|
||||
#include "arrow/util/key_value_metadata.h" // IWYU pragma: export
|
||||
#include "arrow/visitor.h" // IYWU pragma: export
|
||||
|
||||
/// \brief Top-level namespace for Apache Arrow C++ API
|
||||
namespace arrow {}
|
||||
|
||||
#endif // ARROW_API_H
|
1149
cpp/thirdparty/knowhere_build/include/arrow/array.h
vendored
1149
cpp/thirdparty/knowhere_build/include/arrow/array.h
vendored
File diff suppressed because it is too large
Load Diff
@ -1,175 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array/builder_base.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
namespace internal {
|
||||
|
||||
class ARROW_EXPORT AdaptiveIntBuilderBase : public ArrayBuilder {
|
||||
public:
|
||||
explicit AdaptiveIntBuilderBase(MemoryPool* pool);
|
||||
|
||||
/// \brief Append multiple nulls
|
||||
/// \param[in] length the number of nulls to append
|
||||
Status AppendNulls(int64_t length) final {
|
||||
ARROW_RETURN_NOT_OK(CommitPendingData());
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
memset(data_->mutable_data() + length_ * int_size_, 0, int_size_ * length);
|
||||
UnsafeSetNull(length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendNull() final {
|
||||
pending_data_[pending_pos_] = 0;
|
||||
pending_valid_[pending_pos_] = 0;
|
||||
pending_has_nulls_ = true;
|
||||
++pending_pos_;
|
||||
|
||||
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
|
||||
return CommitPendingData();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Reset() override;
|
||||
Status Resize(int64_t capacity) override;
|
||||
|
||||
protected:
|
||||
virtual Status CommitPendingData() = 0;
|
||||
|
||||
std::shared_ptr<ResizableBuffer> data_;
|
||||
uint8_t* raw_data_;
|
||||
uint8_t int_size_;
|
||||
|
||||
static constexpr int32_t pending_size_ = 1024;
|
||||
uint8_t pending_valid_[pending_size_];
|
||||
uint64_t pending_data_[pending_size_];
|
||||
int32_t pending_pos_;
|
||||
bool pending_has_nulls_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
class ARROW_EXPORT AdaptiveUIntBuilder : public internal::AdaptiveIntBuilderBase {
|
||||
public:
|
||||
explicit AdaptiveUIntBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
using ArrayBuilder::Advance;
|
||||
using internal::AdaptiveIntBuilderBase::Reset;
|
||||
|
||||
/// Scalar append
|
||||
Status Append(const uint64_t val) {
|
||||
pending_data_[pending_pos_] = val;
|
||||
pending_valid_[pending_pos_] = 1;
|
||||
++pending_pos_;
|
||||
|
||||
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
|
||||
return CommitPendingData();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous C array of values
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const uint64_t* values, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
protected:
|
||||
Status CommitPendingData() override;
|
||||
Status ExpandIntSize(uint8_t new_int_size);
|
||||
|
||||
Status AppendValuesInternal(const uint64_t* values, int64_t length,
|
||||
const uint8_t* valid_bytes);
|
||||
|
||||
template <typename new_type, typename old_type>
|
||||
typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
|
||||
ExpandIntSizeInternal();
|
||||
#define __LESS(a, b) (a) < (b)
|
||||
template <typename new_type, typename old_type>
|
||||
typename std::enable_if<__LESS(sizeof(old_type), sizeof(new_type)), Status>::type
|
||||
ExpandIntSizeInternal();
|
||||
#undef __LESS
|
||||
|
||||
template <typename new_type>
|
||||
Status ExpandIntSizeN();
|
||||
};
|
||||
|
||||
class ARROW_EXPORT AdaptiveIntBuilder : public internal::AdaptiveIntBuilderBase {
|
||||
public:
|
||||
explicit AdaptiveIntBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
using ArrayBuilder::Advance;
|
||||
using internal::AdaptiveIntBuilderBase::Reset;
|
||||
|
||||
/// Scalar append
|
||||
Status Append(const int64_t val) {
|
||||
auto v = static_cast<uint64_t>(val);
|
||||
|
||||
pending_data_[pending_pos_] = v;
|
||||
pending_valid_[pending_pos_] = 1;
|
||||
++pending_pos_;
|
||||
|
||||
if (ARROW_PREDICT_FALSE(pending_pos_ >= pending_size_)) {
|
||||
return CommitPendingData();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous C array of values
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const int64_t* values, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
protected:
|
||||
Status CommitPendingData() override;
|
||||
Status ExpandIntSize(uint8_t new_int_size);
|
||||
|
||||
Status AppendValuesInternal(const int64_t* values, int64_t length,
|
||||
const uint8_t* valid_bytes);
|
||||
|
||||
template <typename new_type, typename old_type>
|
||||
typename std::enable_if<sizeof(old_type) >= sizeof(new_type), Status>::type
|
||||
ExpandIntSizeInternal();
|
||||
#define __LESS(a, b) (a) < (b)
|
||||
template <typename new_type, typename old_type>
|
||||
typename std::enable_if<__LESS(sizeof(old_type), sizeof(new_type)), Status>::type
|
||||
ExpandIntSizeInternal();
|
||||
#undef __LESS
|
||||
|
||||
template <typename new_type>
|
||||
Status ExpandIntSizeN();
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,219 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm> // IWYU pragma: keep
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/buffer-builder.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/type.h"
|
||||
#include "arrow/type_traits.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/type_traits.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
struct ArrayData;
|
||||
class MemoryPool;
|
||||
|
||||
constexpr int64_t kMinBuilderCapacity = 1 << 5;
|
||||
constexpr int64_t kListMaximumElements = std::numeric_limits<int32_t>::max() - 1;
|
||||
|
||||
/// Base class for all data array builders.
|
||||
///
|
||||
/// This class provides a facilities for incrementally building the null bitmap
|
||||
/// (see Append methods) and as a side effect the current number of slots and
|
||||
/// the null count.
|
||||
///
|
||||
/// \note Users are expected to use builders as one of the concrete types below.
|
||||
/// For example, ArrayBuilder* pointing to BinaryBuilder should be downcast before use.
|
||||
class ARROW_EXPORT ArrayBuilder {
|
||||
public:
|
||||
explicit ArrayBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
|
||||
: type_(type), pool_(pool), null_bitmap_builder_(pool) {}
|
||||
|
||||
virtual ~ArrayBuilder() = default;
|
||||
|
||||
/// For nested types. Since the objects are owned by this class instance, we
|
||||
/// skip shared pointers and just return a raw pointer
|
||||
ArrayBuilder* child(int i) { return children_[i].get(); }
|
||||
|
||||
int num_children() const { return static_cast<int>(children_.size()); }
|
||||
|
||||
int64_t length() const { return length_; }
|
||||
int64_t null_count() const { return null_count_; }
|
||||
int64_t capacity() const { return capacity_; }
|
||||
|
||||
/// \brief Ensure that enough memory has been allocated to fit the indicated
|
||||
/// number of total elements in the builder, including any that have already
|
||||
/// been appended. Does not account for reallocations that may be due to
|
||||
/// variable size data, like binary values. To make space for incremental
|
||||
/// appends, use Reserve instead.
|
||||
///
|
||||
/// \param[in] capacity the minimum number of total array values to
|
||||
/// accommodate. Must be greater than the current capacity.
|
||||
/// \return Status
|
||||
virtual Status Resize(int64_t capacity);
|
||||
|
||||
/// \brief Ensure that there is enough space allocated to add the indicated
|
||||
/// number of elements without any further calls to Resize. Overallocation is
|
||||
/// used in order to minimize the impact of incremental Reserve() calls.
|
||||
///
|
||||
/// \param[in] additional_capacity the number of additional array values
|
||||
/// \return Status
|
||||
Status Reserve(int64_t additional_capacity) {
|
||||
auto current_capacity = capacity();
|
||||
auto min_capacity = length() + additional_capacity;
|
||||
if (min_capacity <= current_capacity) return Status::OK();
|
||||
|
||||
// leave growth factor up to BufferBuilder
|
||||
auto new_capacity = BufferBuilder::GrowByFactor(current_capacity, min_capacity);
|
||||
return Resize(new_capacity);
|
||||
}
|
||||
|
||||
/// Reset the builder.
|
||||
virtual void Reset();
|
||||
|
||||
virtual Status AppendNull() = 0;
|
||||
virtual Status AppendNulls(int64_t length) = 0;
|
||||
|
||||
/// For cases where raw data was memcpy'd into the internal buffers, allows us
|
||||
/// to advance the length of the builder. It is your responsibility to use
|
||||
/// this function responsibly.
|
||||
Status Advance(int64_t elements);
|
||||
|
||||
/// \brief Return result of builder as an internal generic ArrayData
|
||||
/// object. Resets builder except for dictionary builder
|
||||
///
|
||||
/// \param[out] out the finalized ArrayData object
|
||||
/// \return Status
|
||||
virtual Status FinishInternal(std::shared_ptr<ArrayData>* out) = 0;
|
||||
|
||||
/// \brief Return result of builder as an Array object.
|
||||
///
|
||||
/// The builder is reset except for DictionaryBuilder.
|
||||
///
|
||||
/// \param[out] out the finalized Array object
|
||||
/// \return Status
|
||||
Status Finish(std::shared_ptr<Array>* out);
|
||||
|
||||
std::shared_ptr<DataType> type() const { return type_; }
|
||||
|
||||
protected:
|
||||
/// Append to null bitmap
|
||||
Status AppendToBitmap(bool is_valid);
|
||||
|
||||
/// Vector append. Treat each zero byte as a null. If valid_bytes is null
|
||||
/// assume all of length bits are valid.
|
||||
Status AppendToBitmap(const uint8_t* valid_bytes, int64_t length);
|
||||
|
||||
/// Uniform append. Append N times the same validity bit.
|
||||
Status AppendToBitmap(int64_t num_bits, bool value);
|
||||
|
||||
/// Set the next length bits to not null (i.e. valid).
|
||||
Status SetNotNull(int64_t length);
|
||||
|
||||
// Unsafe operations (don't check capacity/don't resize)
|
||||
|
||||
void UnsafeAppendNull() { UnsafeAppendToBitmap(false); }
|
||||
|
||||
// Append to null bitmap, update the length
|
||||
void UnsafeAppendToBitmap(bool is_valid) {
|
||||
null_bitmap_builder_.UnsafeAppend(is_valid);
|
||||
++length_;
|
||||
if (!is_valid) ++null_count_;
|
||||
}
|
||||
|
||||
// Vector append. Treat each zero byte as a nullzero. If valid_bytes is null
|
||||
// assume all of length bits are valid.
|
||||
void UnsafeAppendToBitmap(const uint8_t* valid_bytes, int64_t length) {
|
||||
if (valid_bytes == NULLPTR) {
|
||||
return UnsafeSetNotNull(length);
|
||||
}
|
||||
null_bitmap_builder_.UnsafeAppend(valid_bytes, length);
|
||||
length_ += length;
|
||||
null_count_ = null_bitmap_builder_.false_count();
|
||||
}
|
||||
|
||||
// Append the same validity value a given number of times.
|
||||
void UnsafeAppendToBitmap(const int64_t num_bits, bool value) {
|
||||
if (value) {
|
||||
UnsafeSetNotNull(num_bits);
|
||||
} else {
|
||||
UnsafeSetNull(num_bits);
|
||||
}
|
||||
}
|
||||
|
||||
void UnsafeAppendToBitmap(const std::vector<bool>& is_valid);
|
||||
|
||||
// Set the next validity bits to not null (i.e. valid).
|
||||
void UnsafeSetNotNull(int64_t length);
|
||||
|
||||
// Set the next validity bits to null (i.e. invalid).
|
||||
void UnsafeSetNull(int64_t length);
|
||||
|
||||
static Status TrimBuffer(const int64_t bytes_filled, ResizableBuffer* buffer);
|
||||
|
||||
/// \brief Finish to an array of the specified ArrayType
|
||||
template <typename ArrayType>
|
||||
Status FinishTyped(std::shared_ptr<ArrayType>* out) {
|
||||
std::shared_ptr<Array> out_untyped;
|
||||
ARROW_RETURN_NOT_OK(Finish(&out_untyped));
|
||||
*out = std::static_pointer_cast<ArrayType>(std::move(out_untyped));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
static Status CheckCapacity(int64_t new_capacity, int64_t old_capacity) {
|
||||
if (new_capacity < 0) {
|
||||
return Status::Invalid("Resize capacity must be positive");
|
||||
}
|
||||
|
||||
if (new_capacity < old_capacity) {
|
||||
return Status::Invalid("Resize cannot downsize");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
std::shared_ptr<DataType> type_;
|
||||
MemoryPool* pool_;
|
||||
|
||||
TypedBufferBuilder<bool> null_bitmap_builder_;
|
||||
int64_t null_count_ = 0;
|
||||
|
||||
// Array length, so far. Also, the index of the next element to be added
|
||||
int64_t length_ = 0;
|
||||
int64_t capacity_ = 0;
|
||||
|
||||
// Child value array builders. These are owned by this class
|
||||
std::vector<std::shared_ptr<ArrayBuilder>> children_;
|
||||
|
||||
private:
|
||||
ARROW_DISALLOW_COPY_AND_ASSIGN(ArrayBuilder);
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,395 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/buffer-builder.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/type_traits.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/string_view.h" // IWYU pragma: export
|
||||
|
||||
namespace arrow {
|
||||
|
||||
constexpr int64_t kBinaryMemoryLimit = std::numeric_limits<int32_t>::max() - 1;
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Binary and String
|
||||
|
||||
/// \class BinaryBuilder
|
||||
/// \brief Builder class for variable-length binary data
|
||||
class ARROW_EXPORT BinaryBuilder : public ArrayBuilder {
|
||||
public:
|
||||
explicit BinaryBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
BinaryBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool);
|
||||
|
||||
Status Append(const uint8_t* value, int32_t length) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
ARROW_RETURN_NOT_OK(AppendNextOffset());
|
||||
// Safety check for UBSAN.
|
||||
if (ARROW_PREDICT_TRUE(length > 0)) {
|
||||
ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
|
||||
}
|
||||
|
||||
UnsafeAppendToBitmap(true);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendNulls(int64_t length) final {
|
||||
const int64_t num_bytes = value_data_builder_.length();
|
||||
if (ARROW_PREDICT_FALSE(num_bytes > kBinaryMemoryLimit)) {
|
||||
return AppendOverflow(num_bytes);
|
||||
}
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
for (int64_t i = 0; i < length; ++i) {
|
||||
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
|
||||
}
|
||||
UnsafeAppendToBitmap(length, false);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendNull() final {
|
||||
ARROW_RETURN_NOT_OK(AppendNextOffset());
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppendToBitmap(false);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const char* value, int32_t length) {
|
||||
return Append(reinterpret_cast<const uint8_t*>(value), length);
|
||||
}
|
||||
|
||||
Status Append(util::string_view value) {
|
||||
return Append(value.data(), static_cast<int32_t>(value.size()));
|
||||
}
|
||||
|
||||
/// \brief Append without checking capacity
|
||||
///
|
||||
/// Offsets and data should have been presized using Reserve() and
|
||||
/// ReserveData(), respectively.
|
||||
void UnsafeAppend(const uint8_t* value, int32_t length) {
|
||||
UnsafeAppendNextOffset();
|
||||
value_data_builder_.UnsafeAppend(value, length);
|
||||
UnsafeAppendToBitmap(true);
|
||||
}
|
||||
|
||||
void UnsafeAppend(const char* value, int32_t length) {
|
||||
UnsafeAppend(reinterpret_cast<const uint8_t*>(value), length);
|
||||
}
|
||||
|
||||
void UnsafeAppend(const std::string& value) {
|
||||
UnsafeAppend(value.c_str(), static_cast<int32_t>(value.size()));
|
||||
}
|
||||
|
||||
void UnsafeAppend(util::string_view value) {
|
||||
UnsafeAppend(value.data(), static_cast<int32_t>(value.size()));
|
||||
}
|
||||
|
||||
void UnsafeAppendNull() {
|
||||
const int64_t num_bytes = value_data_builder_.length();
|
||||
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
|
||||
UnsafeAppendToBitmap(false);
|
||||
}
|
||||
|
||||
void Reset() override;
|
||||
Status Resize(int64_t capacity) override;
|
||||
|
||||
/// \brief Ensures there is enough allocated capacity to append the indicated
|
||||
/// number of bytes to the value data buffer without additional allocations
|
||||
Status ReserveData(int64_t elements);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<BinaryArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \return size of values buffer so far
|
||||
int64_t value_data_length() const { return value_data_builder_.length(); }
|
||||
/// \return capacity of values buffer
|
||||
int64_t value_data_capacity() const { return value_data_builder_.capacity(); }
|
||||
|
||||
/// Temporary access to a value.
|
||||
///
|
||||
/// This pointer becomes invalid on the next modifying operation.
|
||||
const uint8_t* GetValue(int64_t i, int32_t* out_length) const;
|
||||
|
||||
/// Temporary access to a value.
|
||||
///
|
||||
/// This view becomes invalid on the next modifying operation.
|
||||
util::string_view GetView(int64_t i) const;
|
||||
|
||||
protected:
|
||||
TypedBufferBuilder<int32_t> offsets_builder_;
|
||||
TypedBufferBuilder<uint8_t> value_data_builder_;
|
||||
|
||||
Status AppendOverflow(int64_t num_bytes);
|
||||
|
||||
Status AppendNextOffset() {
|
||||
const int64_t num_bytes = value_data_builder_.length();
|
||||
if (ARROW_PREDICT_FALSE(num_bytes > kBinaryMemoryLimit)) {
|
||||
return AppendOverflow(num_bytes);
|
||||
}
|
||||
return offsets_builder_.Append(static_cast<int32_t>(num_bytes));
|
||||
}
|
||||
|
||||
void UnsafeAppendNextOffset() {
|
||||
const int64_t num_bytes = value_data_builder_.length();
|
||||
offsets_builder_.UnsafeAppend(static_cast<int32_t>(num_bytes));
|
||||
}
|
||||
};
|
||||
|
||||
/// \class StringBuilder
|
||||
/// \brief Builder class for UTF8 strings
|
||||
class ARROW_EXPORT StringBuilder : public BinaryBuilder {
|
||||
public:
|
||||
using BinaryBuilder::BinaryBuilder;
|
||||
explicit StringBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
using BinaryBuilder::Append;
|
||||
using BinaryBuilder::Reset;
|
||||
using BinaryBuilder::UnsafeAppend;
|
||||
|
||||
/// \brief Append a sequence of strings in one shot.
|
||||
///
|
||||
/// \param[in] values a vector of strings
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<std::string>& values,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \brief Append a sequence of nul-terminated strings in one shot.
|
||||
/// If one of the values is NULL, it is processed as a null
|
||||
/// value even if the corresponding valid_bytes entry is 1.
|
||||
///
|
||||
/// \param[in] values a contiguous C array of nul-terminated char *
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const char** values, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<StringArray>* out) { return FinishTyped(out); }
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// FixedSizeBinaryBuilder
|
||||
|
||||
class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
|
||||
public:
|
||||
FixedSizeBinaryBuilder(const std::shared_ptr<DataType>& type,
|
||||
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
Status Append(const uint8_t* value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const char* value) {
|
||||
return Append(reinterpret_cast<const uint8_t*>(value));
|
||||
}
|
||||
|
||||
Status Append(const util::string_view& view) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(view);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const std::string& s) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(s);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <size_t NBYTES>
|
||||
Status Append(const std::array<uint8_t, NBYTES>& value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(
|
||||
util::string_view(reinterpret_cast<const char*>(value.data()), value.size()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendValues(const uint8_t* data, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
Status AppendNull() final;
|
||||
|
||||
Status AppendNulls(int64_t length) final;
|
||||
|
||||
void UnsafeAppend(const uint8_t* value) {
|
||||
UnsafeAppendToBitmap(true);
|
||||
if (ARROW_PREDICT_TRUE(byte_width_ > 0)) {
|
||||
byte_builder_.UnsafeAppend(value, byte_width_);
|
||||
}
|
||||
}
|
||||
|
||||
void UnsafeAppend(util::string_view value) {
|
||||
#ifndef NDEBUG
|
||||
CheckValueSize(static_cast<size_t>(value.size()));
|
||||
#endif
|
||||
UnsafeAppend(reinterpret_cast<const uint8_t*>(value.data()));
|
||||
}
|
||||
|
||||
void UnsafeAppendNull() {
|
||||
UnsafeAppendToBitmap(false);
|
||||
byte_builder_.UnsafeAdvance(byte_width_);
|
||||
}
|
||||
|
||||
void Reset() override;
|
||||
Status Resize(int64_t capacity) override;
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<FixedSizeBinaryArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \return size of values buffer so far
|
||||
int64_t value_data_length() const { return byte_builder_.length(); }
|
||||
|
||||
int32_t byte_width() const { return byte_width_; }
|
||||
|
||||
/// Temporary access to a value.
|
||||
///
|
||||
/// This pointer becomes invalid on the next modifying operation.
|
||||
const uint8_t* GetValue(int64_t i) const;
|
||||
|
||||
/// Temporary access to a value.
|
||||
///
|
||||
/// This view becomes invalid on the next modifying operation.
|
||||
util::string_view GetView(int64_t i) const;
|
||||
|
||||
protected:
|
||||
int32_t byte_width_;
|
||||
BufferBuilder byte_builder_;
|
||||
|
||||
/// Temporary access to a value.
|
||||
///
|
||||
/// This pointer becomes invalid on the next modifying operation.
|
||||
uint8_t* GetMutableValue(int64_t i) {
|
||||
uint8_t* data_ptr = byte_builder_.mutable_data();
|
||||
return data_ptr + i * byte_width_;
|
||||
}
|
||||
|
||||
#ifndef NDEBUG
|
||||
void CheckValueSize(int64_t size);
|
||||
#endif
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Chunked builders: build a sequence of BinaryArray or StringArray that are
|
||||
// limited to a particular size (to the upper limit of 2GB)
|
||||
|
||||
namespace internal {
|
||||
|
||||
class ARROW_EXPORT ChunkedBinaryBuilder {
|
||||
public:
|
||||
ChunkedBinaryBuilder(int32_t max_chunk_value_length,
|
||||
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
ChunkedBinaryBuilder(int32_t max_chunk_value_length, int32_t max_chunk_length,
|
||||
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
virtual ~ChunkedBinaryBuilder() = default;
|
||||
|
||||
Status Append(const uint8_t* value, int32_t length) {
|
||||
if (ARROW_PREDICT_FALSE(length + builder_->value_data_length() >
|
||||
max_chunk_value_length_)) {
|
||||
if (builder_->value_data_length() == 0) {
|
||||
// The current item is larger than max_chunk_size_;
|
||||
// this chunk will be oversize and hold *only* this item
|
||||
ARROW_RETURN_NOT_OK(builder_->Append(value, length));
|
||||
return NextChunk();
|
||||
}
|
||||
// The current item would cause builder_->value_data_length() to exceed
|
||||
// max_chunk_size_, so finish this chunk and append the current item to the next
|
||||
// chunk
|
||||
ARROW_RETURN_NOT_OK(NextChunk());
|
||||
return Append(value, length);
|
||||
}
|
||||
|
||||
if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
|
||||
// The current item would cause builder_->value_data_length() to exceed
|
||||
// max_chunk_size_, so finish this chunk and append the current item to the next
|
||||
// chunk
|
||||
ARROW_RETURN_NOT_OK(NextChunk());
|
||||
}
|
||||
|
||||
return builder_->Append(value, length);
|
||||
}
|
||||
|
||||
Status Append(const util::string_view& value) {
|
||||
return Append(reinterpret_cast<const uint8_t*>(value.data()),
|
||||
static_cast<int32_t>(value.size()));
|
||||
}
|
||||
|
||||
Status AppendNull() {
|
||||
if (ARROW_PREDICT_FALSE(builder_->length() == max_chunk_length_)) {
|
||||
ARROW_RETURN_NOT_OK(NextChunk());
|
||||
}
|
||||
return builder_->AppendNull();
|
||||
}
|
||||
|
||||
Status Reserve(int64_t values);
|
||||
|
||||
virtual Status Finish(ArrayVector* out);
|
||||
|
||||
protected:
|
||||
Status NextChunk();
|
||||
|
||||
// maximum total character data size per chunk
|
||||
int64_t max_chunk_value_length_;
|
||||
|
||||
// maximum elements allowed per chunk
|
||||
int64_t max_chunk_length_ = kListMaximumElements;
|
||||
|
||||
// when Reserve() would cause builder_ to exceed its max_chunk_length_,
|
||||
// add to extra_capacity_ instead and wait to reserve until the next chunk
|
||||
int64_t extra_capacity_ = 0;
|
||||
|
||||
std::unique_ptr<BinaryBuilder> builder_;
|
||||
std::vector<std::shared_ptr<Array>> chunks_;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT ChunkedStringBuilder : public ChunkedBinaryBuilder {
|
||||
public:
|
||||
using ChunkedBinaryBuilder::ChunkedBinaryBuilder;
|
||||
|
||||
Status Finish(ArrayVector* out) override;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
} // namespace arrow
|
@ -1,53 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/array/builder_binary.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Decimal128;
|
||||
|
||||
class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder {
|
||||
public:
|
||||
explicit Decimal128Builder(const std::shared_ptr<DataType>& type,
|
||||
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
using FixedSizeBinaryBuilder::Append;
|
||||
using FixedSizeBinaryBuilder::AppendValues;
|
||||
using FixedSizeBinaryBuilder::Reset;
|
||||
|
||||
Status Append(Decimal128 val);
|
||||
void UnsafeAppend(Decimal128 val);
|
||||
void UnsafeAppend(util::string_view val);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<Decimal128Array>* out) { return FinishTyped(out); }
|
||||
};
|
||||
|
||||
using DecimalBuilder = Decimal128Builder;
|
||||
|
||||
} // namespace arrow
|
@ -1,369 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array/builder_adaptive.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_base.h" // IWYU pragma: export
|
||||
|
||||
#include "arrow/array.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Dictionary builder
|
||||
|
||||
namespace internal {
|
||||
|
||||
template <typename T>
|
||||
struct DictionaryScalar {
|
||||
using type = typename T::c_type;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DictionaryScalar<BinaryType> {
|
||||
using type = util::string_view;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DictionaryScalar<StringType> {
|
||||
using type = util::string_view;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DictionaryScalar<FixedSizeBinaryType> {
|
||||
using type = util::string_view;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT DictionaryMemoTable {
|
||||
public:
|
||||
explicit DictionaryMemoTable(const std::shared_ptr<DataType>& type);
|
||||
explicit DictionaryMemoTable(const std::shared_ptr<Array>& dictionary);
|
||||
~DictionaryMemoTable();
|
||||
|
||||
int32_t GetOrInsert(const bool& value);
|
||||
int32_t GetOrInsert(const int8_t& value);
|
||||
int32_t GetOrInsert(const int16_t& value);
|
||||
int32_t GetOrInsert(const int32_t& value);
|
||||
int32_t GetOrInsert(const int64_t& value);
|
||||
int32_t GetOrInsert(const uint8_t& value);
|
||||
int32_t GetOrInsert(const uint16_t& value);
|
||||
int32_t GetOrInsert(const uint32_t& value);
|
||||
int32_t GetOrInsert(const uint64_t& value);
|
||||
int32_t GetOrInsert(const float& value);
|
||||
int32_t GetOrInsert(const double& value);
|
||||
int32_t GetOrInsert(const util::string_view& value);
|
||||
|
||||
Status GetArrayData(MemoryPool* pool, int64_t start_offset,
|
||||
std::shared_ptr<ArrayData>* out);
|
||||
|
||||
int32_t size() const;
|
||||
|
||||
private:
|
||||
class DictionaryMemoTableImpl;
|
||||
std::unique_ptr<DictionaryMemoTableImpl> impl_;
|
||||
};
|
||||
|
||||
} // namespace internal
|
||||
|
||||
/// \brief Array builder for created encoded DictionaryArray from
|
||||
/// dense array
|
||||
///
|
||||
/// Unlike other builders, dictionary builder does not completely
|
||||
/// reset the state on Finish calls. The arrays built after the
|
||||
/// initial Finish call will reuse the previously created encoding and
|
||||
/// build a delta dictionary when new terms occur.
|
||||
///
|
||||
/// data
|
||||
template <typename T>
|
||||
class DictionaryBuilder : public ArrayBuilder {
|
||||
public:
|
||||
using Scalar = typename internal::DictionaryScalar<T>::type;
|
||||
|
||||
// WARNING: the type given below is the value type, not the DictionaryType.
|
||||
// The DictionaryType is instantiated on the Finish() call.
|
||||
template <typename T1 = T>
|
||||
DictionaryBuilder(
|
||||
typename std::enable_if<!std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const std::shared_ptr<DataType>&>::type type,
|
||||
MemoryPool* pool)
|
||||
: ArrayBuilder(type, pool),
|
||||
memo_table_(new internal::DictionaryMemoTable(type)),
|
||||
delta_offset_(0),
|
||||
byte_width_(-1),
|
||||
values_builder_(pool) {}
|
||||
|
||||
template <typename T1 = T>
|
||||
explicit DictionaryBuilder(
|
||||
typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const std::shared_ptr<DataType>&>::type type,
|
||||
MemoryPool* pool)
|
||||
: ArrayBuilder(type, pool),
|
||||
memo_table_(new internal::DictionaryMemoTable(type)),
|
||||
delta_offset_(0),
|
||||
byte_width_(static_cast<const T1&>(*type).byte_width()),
|
||||
values_builder_(pool) {}
|
||||
|
||||
template <typename T1 = T>
|
||||
explicit DictionaryBuilder(
|
||||
typename std::enable_if<TypeTraits<T1>::is_parameter_free, MemoryPool*>::type pool)
|
||||
: DictionaryBuilder<T1>(TypeTraits<T1>::type_singleton(), pool) {}
|
||||
|
||||
DictionaryBuilder(const std::shared_ptr<Array>& dictionary, MemoryPool* pool)
|
||||
: ArrayBuilder(dictionary->type(), pool),
|
||||
memo_table_(new internal::DictionaryMemoTable(dictionary)),
|
||||
delta_offset_(0),
|
||||
byte_width_(-1),
|
||||
values_builder_(pool) {}
|
||||
|
||||
~DictionaryBuilder() override = default;
|
||||
|
||||
/// \brief Append a scalar value
|
||||
Status Append(const Scalar& value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
|
||||
auto memo_index = memo_table_->GetOrInsert(value);
|
||||
ARROW_RETURN_NOT_OK(values_builder_.Append(memo_index));
|
||||
length_ += 1;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a fixed-width string (only for FixedSizeBinaryType)
|
||||
template <typename T1 = T>
|
||||
Status Append(typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const uint8_t*>::type value) {
|
||||
return Append(util::string_view(reinterpret_cast<const char*>(value), byte_width_));
|
||||
}
|
||||
|
||||
/// \brief Append a fixed-width string (only for FixedSizeBinaryType)
|
||||
template <typename T1 = T>
|
||||
Status Append(typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const char*>::type value) {
|
||||
return Append(util::string_view(value, byte_width_));
|
||||
}
|
||||
|
||||
/// \brief Append a scalar null value
|
||||
Status AppendNull() final {
|
||||
length_ += 1;
|
||||
null_count_ += 1;
|
||||
|
||||
return values_builder_.AppendNull();
|
||||
}
|
||||
|
||||
Status AppendNulls(int64_t length) final {
|
||||
length_ += length;
|
||||
null_count_ += length;
|
||||
|
||||
return values_builder_.AppendNulls(length);
|
||||
}
|
||||
|
||||
/// \brief Append a whole dense array to the builder
|
||||
template <typename T1 = T>
|
||||
Status AppendArray(
|
||||
typename std::enable_if<!std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const Array&>::type array) {
|
||||
using ArrayType = typename TypeTraits<T>::ArrayType;
|
||||
|
||||
const auto& concrete_array = static_cast<const ArrayType&>(array);
|
||||
for (int64_t i = 0; i < array.length(); i++) {
|
||||
if (array.IsNull(i)) {
|
||||
ARROW_RETURN_NOT_OK(AppendNull());
|
||||
} else {
|
||||
ARROW_RETURN_NOT_OK(Append(concrete_array.GetView(i)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T1 = T>
|
||||
Status AppendArray(
|
||||
typename std::enable_if<std::is_base_of<FixedSizeBinaryType, T1>::value,
|
||||
const Array&>::type array) {
|
||||
if (!type_->Equals(*array.type())) {
|
||||
return Status::Invalid(
|
||||
"Cannot append FixedSizeBinary array with non-matching type");
|
||||
}
|
||||
|
||||
const auto& concrete_array = static_cast<const FixedSizeBinaryArray&>(array);
|
||||
for (int64_t i = 0; i < array.length(); i++) {
|
||||
if (array.IsNull(i)) {
|
||||
ARROW_RETURN_NOT_OK(AppendNull());
|
||||
} else {
|
||||
ARROW_RETURN_NOT_OK(Append(concrete_array.GetValue(i)));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
ArrayBuilder::Reset();
|
||||
values_builder_.Reset();
|
||||
memo_table_.reset(new internal::DictionaryMemoTable(type_));
|
||||
delta_offset_ = 0;
|
||||
}
|
||||
|
||||
Status Resize(int64_t capacity) override {
|
||||
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
|
||||
capacity = std::max(capacity, kMinBuilderCapacity);
|
||||
|
||||
if (capacity_ == 0) {
|
||||
// Initialize hash table
|
||||
// XXX should we let the user pass additional size heuristics?
|
||||
delta_offset_ = 0;
|
||||
}
|
||||
ARROW_RETURN_NOT_OK(values_builder_.Resize(capacity));
|
||||
capacity_ = values_builder_.capacity();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
|
||||
// Finalize indices array
|
||||
ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out));
|
||||
|
||||
// Generate dictionary array from hash table contents
|
||||
std::shared_ptr<ArrayData> dictionary_data;
|
||||
|
||||
ARROW_RETURN_NOT_OK(
|
||||
memo_table_->GetArrayData(pool_, delta_offset_, &dictionary_data));
|
||||
|
||||
// Set type of array data to the right dictionary type
|
||||
(*out)->type = dictionary((*out)->type, type_);
|
||||
(*out)->dictionary = MakeArray(dictionary_data);
|
||||
|
||||
// Update internals for further uses of this DictionaryBuilder
|
||||
delta_offset_ = memo_table_->size();
|
||||
values_builder_.Reset();
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// is the dictionary builder in the delta building mode
|
||||
bool is_building_delta() { return delta_offset_ > 0; }
|
||||
|
||||
protected:
|
||||
std::unique_ptr<internal::DictionaryMemoTable> memo_table_;
|
||||
|
||||
int32_t delta_offset_;
|
||||
// Only used for FixedSizeBinaryType
|
||||
int32_t byte_width_;
|
||||
|
||||
AdaptiveIntBuilder values_builder_;
|
||||
};
|
||||
|
||||
template <>
|
||||
class DictionaryBuilder<NullType> : public ArrayBuilder {
|
||||
public:
|
||||
DictionaryBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool)
|
||||
: ArrayBuilder(type, pool), values_builder_(pool) {}
|
||||
explicit DictionaryBuilder(MemoryPool* pool)
|
||||
: ArrayBuilder(null(), pool), values_builder_(pool) {}
|
||||
|
||||
DictionaryBuilder(const std::shared_ptr<Array>& dictionary, MemoryPool* pool)
|
||||
: ArrayBuilder(dictionary->type(), pool), values_builder_(pool) {}
|
||||
|
||||
/// \brief Append a scalar null value
|
||||
Status AppendNull() final {
|
||||
length_ += 1;
|
||||
null_count_ += 1;
|
||||
|
||||
return values_builder_.AppendNull();
|
||||
}
|
||||
|
||||
Status AppendNulls(int64_t length) final {
|
||||
length_ += length;
|
||||
null_count_ += length;
|
||||
|
||||
return values_builder_.AppendNulls(length);
|
||||
}
|
||||
|
||||
/// \brief Append a whole dense array to the builder
|
||||
Status AppendArray(const Array& array) {
|
||||
for (int64_t i = 0; i < array.length(); i++) {
|
||||
ARROW_RETURN_NOT_OK(AppendNull());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Resize(int64_t capacity) override {
|
||||
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
|
||||
capacity = std::max(capacity, kMinBuilderCapacity);
|
||||
|
||||
ARROW_RETURN_NOT_OK(values_builder_.Resize(capacity));
|
||||
capacity_ = values_builder_.capacity();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
|
||||
std::shared_ptr<Array> dictionary = std::make_shared<NullArray>(0);
|
||||
|
||||
ARROW_RETURN_NOT_OK(values_builder_.FinishInternal(out));
|
||||
(*out)->type = std::make_shared<DictionaryType>((*out)->type, type_);
|
||||
(*out)->dictionary = dictionary;
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<DictionaryArray>* out) { return FinishTyped(out); }
|
||||
|
||||
protected:
|
||||
AdaptiveIntBuilder values_builder_;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT BinaryDictionaryBuilder : public DictionaryBuilder<BinaryType> {
|
||||
public:
|
||||
using DictionaryBuilder::Append;
|
||||
using DictionaryBuilder::DictionaryBuilder;
|
||||
|
||||
Status Append(const uint8_t* value, int32_t length) {
|
||||
return Append(reinterpret_cast<const char*>(value), length);
|
||||
}
|
||||
|
||||
Status Append(const char* value, int32_t length) {
|
||||
return Append(util::string_view(value, length));
|
||||
}
|
||||
};
|
||||
|
||||
/// \brief Dictionary array builder with convenience methods for strings
|
||||
class ARROW_EXPORT StringDictionaryBuilder : public DictionaryBuilder<StringType> {
|
||||
public:
|
||||
using DictionaryBuilder::Append;
|
||||
using DictionaryBuilder::DictionaryBuilder;
|
||||
|
||||
Status Append(const uint8_t* value, int32_t length) {
|
||||
return Append(reinterpret_cast<const char*>(value), length);
|
||||
}
|
||||
|
||||
Status Append(const char* value, int32_t length) {
|
||||
return Append(util::string_view(value, length));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,260 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/buffer-builder.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// List builder
|
||||
|
||||
/// \class ListBuilder
|
||||
/// \brief Builder class for variable-length list array value types
|
||||
///
|
||||
/// To use this class, you must append values to the child array builder and use
|
||||
/// the Append function to delimit each distinct list value (once the values
|
||||
/// have been appended to the child array) or use the bulk API to append
|
||||
/// a sequence of offests and null values.
|
||||
///
|
||||
/// A note on types. Per arrow/type.h all types in the c++ implementation are
|
||||
/// logical so even though this class always builds list array, this can
|
||||
/// represent multiple different logical types. If no logical type is provided
|
||||
/// at construction time, the class defaults to List<T> where t is taken from the
|
||||
/// value_builder/values that the object is constructed with.
|
||||
class ARROW_EXPORT ListBuilder : public ArrayBuilder {
|
||||
public:
|
||||
/// Use this constructor to incrementally build the value array along with offsets and
|
||||
/// null bitmap.
|
||||
ListBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& value_builder,
|
||||
const std::shared_ptr<DataType>& type = NULLPTR);
|
||||
|
||||
Status Resize(int64_t capacity) override;
|
||||
void Reset() override;
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<ListArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \brief Vector append
|
||||
///
|
||||
/// If passed, valid_bytes is of equal length to values, and any zero byte
|
||||
/// will be considered as a null for that slot
|
||||
Status AppendValues(const int32_t* offsets, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \brief Start a new variable-length list slot
|
||||
///
|
||||
/// This function should be called before beginning to append elements to the
|
||||
/// value builder
|
||||
Status Append(bool is_valid = true);
|
||||
|
||||
Status AppendNull() final { return Append(false); }
|
||||
|
||||
Status AppendNulls(int64_t length) final;
|
||||
|
||||
ArrayBuilder* value_builder() const;
|
||||
|
||||
protected:
|
||||
TypedBufferBuilder<int32_t> offsets_builder_;
|
||||
std::shared_ptr<ArrayBuilder> value_builder_;
|
||||
std::shared_ptr<Array> values_;
|
||||
|
||||
Status CheckNextOffset() const;
|
||||
Status AppendNextOffset();
|
||||
Status AppendNextOffset(int64_t num_repeats);
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Map builder
|
||||
|
||||
/// \class MapBuilder
|
||||
/// \brief Builder class for arrays of variable-size maps
|
||||
///
|
||||
/// To use this class, you must append values to the key and item array builders
|
||||
/// and use the Append function to delimit each distinct map (once the keys and items
|
||||
/// have been appended) or use the bulk API to append a sequence of offests and null
|
||||
/// maps.
|
||||
///
|
||||
/// Key uniqueness and ordering are not validated.
|
||||
class ARROW_EXPORT MapBuilder : public ArrayBuilder {
|
||||
public:
|
||||
/// Use this constructor to incrementally build the key and item arrays along with
|
||||
/// offsets and null bitmap.
|
||||
MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
|
||||
const std::shared_ptr<ArrayBuilder>& item_builder,
|
||||
const std::shared_ptr<DataType>& type);
|
||||
|
||||
/// Derive built type from key and item builders' types
|
||||
MapBuilder(MemoryPool* pool, const std::shared_ptr<ArrayBuilder>& key_builder,
|
||||
const std::shared_ptr<ArrayBuilder>& item_builder, bool keys_sorted = false);
|
||||
|
||||
Status Resize(int64_t capacity) override;
|
||||
void Reset() override;
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<MapArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \brief Vector append
|
||||
///
|
||||
/// If passed, valid_bytes is of equal length to values, and any zero byte
|
||||
/// will be considered as a null for that slot
|
||||
Status AppendValues(const int32_t* offsets, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \brief Start a new variable-length map slot
|
||||
///
|
||||
/// This function should be called before beginning to append elements to the
|
||||
/// key and value builders
|
||||
Status Append();
|
||||
|
||||
Status AppendNull() final;
|
||||
|
||||
Status AppendNulls(int64_t length) final;
|
||||
|
||||
ArrayBuilder* key_builder() const { return key_builder_.get(); }
|
||||
ArrayBuilder* item_builder() const { return item_builder_.get(); }
|
||||
|
||||
protected:
|
||||
std::shared_ptr<ListBuilder> list_builder_;
|
||||
std::shared_ptr<ArrayBuilder> key_builder_;
|
||||
std::shared_ptr<ArrayBuilder> item_builder_;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// FixedSizeList builder
|
||||
|
||||
/// \class FixedSizeListBuilder
|
||||
/// \brief Builder class for fixed-length list array value types
|
||||
class ARROW_EXPORT FixedSizeListBuilder : public ArrayBuilder {
|
||||
public:
|
||||
FixedSizeListBuilder(MemoryPool* pool,
|
||||
std::shared_ptr<ArrayBuilder> const& value_builder,
|
||||
int32_t list_size);
|
||||
|
||||
FixedSizeListBuilder(MemoryPool* pool,
|
||||
std::shared_ptr<ArrayBuilder> const& value_builder,
|
||||
const std::shared_ptr<DataType>& type);
|
||||
|
||||
Status Resize(int64_t capacity) override;
|
||||
void Reset() override;
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<FixedSizeListArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \brief Append a valid fixed length list.
|
||||
///
|
||||
/// This function affects only the validity bitmap; the child values must be appended
|
||||
/// using the child array builder.
|
||||
Status Append();
|
||||
|
||||
/// \brief Vector append
|
||||
///
|
||||
/// If passed, valid_bytes wil be read and any zero byte
|
||||
/// will cause the corresponding slot to be null
|
||||
///
|
||||
/// This function affects only the validity bitmap; the child values must be appended
|
||||
/// using the child array builder. This includes appending nulls for null lists.
|
||||
/// XXX this restriction is confusing, should this method be omitted?
|
||||
Status AppendValues(int64_t length, const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \brief Append a null fixed length list.
|
||||
///
|
||||
/// The child array builder will have the approriate number of nulls appended
|
||||
/// automatically.
|
||||
Status AppendNull() final;
|
||||
|
||||
/// \brief Append length null fixed length lists.
|
||||
///
|
||||
/// The child array builder will have the approriate number of nulls appended
|
||||
/// automatically.
|
||||
Status AppendNulls(int64_t length) final;
|
||||
|
||||
ArrayBuilder* value_builder() const { return value_builder_.get(); }
|
||||
|
||||
protected:
|
||||
const int32_t list_size_;
|
||||
std::shared_ptr<ArrayBuilder> value_builder_;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Struct
|
||||
|
||||
// ---------------------------------------------------------------------------------
|
||||
// StructArray builder
|
||||
/// Append, Resize and Reserve methods are acting on StructBuilder.
|
||||
/// Please make sure all these methods of all child-builders' are consistently
|
||||
/// called to maintain data-structure consistency.
|
||||
class ARROW_EXPORT StructBuilder : public ArrayBuilder {
|
||||
public:
|
||||
StructBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool,
|
||||
std::vector<std::shared_ptr<ArrayBuilder>>&& field_builders);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<StructArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// Null bitmap is of equal length to every child field, and any zero byte
|
||||
/// will be considered as a null for that field, but users must using app-
|
||||
/// end methods or advance methods of the child builders' independently to
|
||||
/// insert data.
|
||||
Status AppendValues(int64_t length, const uint8_t* valid_bytes) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
UnsafeAppendToBitmap(valid_bytes, length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Append an element to the Struct. All child-builders' Append method must
|
||||
/// be called independently to maintain data-structure consistency.
|
||||
Status Append(bool is_valid = true) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppendToBitmap(is_valid);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendNull() final { return Append(false); }
|
||||
|
||||
Status AppendNulls(int64_t length) final;
|
||||
|
||||
void Reset() override;
|
||||
|
||||
ArrayBuilder* field_builder(int i) const { return children_[i].get(); }
|
||||
|
||||
int num_fields() const { return static_cast<int>(children_.size()); }
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,429 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/type.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class ARROW_EXPORT NullBuilder : public ArrayBuilder {
|
||||
public:
|
||||
explicit NullBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: ArrayBuilder(null(), pool) {}
|
||||
|
||||
/// \brief Append the specified number of null elements
|
||||
Status AppendNulls(int64_t length) final {
|
||||
if (length < 0) return Status::Invalid("length must be positive");
|
||||
null_count_ += length;
|
||||
length_ += length;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a single null element
|
||||
Status AppendNull() final { return AppendNulls(1); }
|
||||
|
||||
Status Append(std::nullptr_t) { return AppendNull(); }
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<NullArray>* out) { return FinishTyped(out); }
|
||||
};
|
||||
|
||||
/// Base class for all Builders that emit an Array of a scalar numerical type.
|
||||
template <typename T>
|
||||
class NumericBuilder : public ArrayBuilder {
|
||||
public:
|
||||
using value_type = typename T::c_type;
|
||||
using ArrayType = typename TypeTraits<T>::ArrayType;
|
||||
using ArrayBuilder::ArrayBuilder;
|
||||
|
||||
template <typename T1 = T>
|
||||
explicit NumericBuilder(
|
||||
typename std::enable_if<TypeTraits<T1>::is_parameter_free, MemoryPool*>::type pool
|
||||
ARROW_MEMORY_POOL_DEFAULT)
|
||||
: ArrayBuilder(TypeTraits<T1>::type_singleton(), pool) {}
|
||||
|
||||
/// Append a single scalar and increase the size if necessary.
|
||||
Status Append(const value_type val) {
|
||||
ARROW_RETURN_NOT_OK(ArrayBuilder::Reserve(1));
|
||||
UnsafeAppend(val);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
|
||||
/// The memory at the corresponding data slot is set to 0 to prevent
|
||||
/// uninitialized memory access
|
||||
Status AppendNulls(int64_t length) final {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(length, static_cast<value_type>(0));
|
||||
UnsafeSetNull(length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a single null element
|
||||
Status AppendNull() final {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
data_builder_.UnsafeAppend(static_cast<value_type>(0));
|
||||
UnsafeAppendToBitmap(false);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
value_type GetValue(int64_t index) const { return data_builder_.data()[index]; }
|
||||
|
||||
void Reset() override { data_builder_.Reset(); }
|
||||
|
||||
Status Resize(int64_t capacity) override {
|
||||
ARROW_RETURN_NOT_OK(CheckCapacity(capacity, capacity_));
|
||||
capacity = std::max(capacity, kMinBuilderCapacity);
|
||||
ARROW_RETURN_NOT_OK(data_builder_.Resize(capacity));
|
||||
return ArrayBuilder::Resize(capacity);
|
||||
}
|
||||
|
||||
value_type operator[](int64_t index) const { return GetValue(index); }
|
||||
|
||||
value_type& operator[](int64_t index) {
|
||||
return reinterpret_cast<value_type*>(data_builder_.mutable_data())[index];
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous C array of values
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const value_type* values, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(values, length);
|
||||
// length_ is update by these
|
||||
ArrayBuilder::UnsafeAppendToBitmap(valid_bytes, length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous C array of values
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
|
||||
/// (0). Equal in length to values
|
||||
/// \return Status
|
||||
Status AppendValues(const value_type* values, int64_t length,
|
||||
const std::vector<bool>& is_valid) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(values, length);
|
||||
// length_ is update by these
|
||||
ArrayBuilder::UnsafeAppendToBitmap(is_valid);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a std::vector of values
|
||||
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
|
||||
/// (0). Equal in length to values
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<value_type>& values,
|
||||
const std::vector<bool>& is_valid) {
|
||||
return AppendValues(values.data(), static_cast<int64_t>(values.size()), is_valid);
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a std::vector of values
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<value_type>& values) {
|
||||
return AppendValues(values.data(), static_cast<int64_t>(values.size()));
|
||||
}
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
|
||||
std::shared_ptr<Buffer> data, null_bitmap;
|
||||
ARROW_RETURN_NOT_OK(null_bitmap_builder_.Finish(&null_bitmap));
|
||||
ARROW_RETURN_NOT_OK(data_builder_.Finish(&data));
|
||||
*out = ArrayData::Make(type_, length_, {null_bitmap, data}, null_count_);
|
||||
capacity_ = length_ = null_count_ = 0;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<ArrayType>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values_begin InputIterator to the beginning of the values
|
||||
/// \param[in] values_end InputIterator pointing to the end of the values
|
||||
/// \return Status
|
||||
template <typename ValuesIter>
|
||||
Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(values_begin, values_end);
|
||||
// this updates the length_
|
||||
UnsafeSetNotNull(length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot, with a specified nullmap
|
||||
/// \param[in] values_begin InputIterator to the beginning of the values
|
||||
/// \param[in] values_end InputIterator pointing to the end of the values
|
||||
/// \param[in] valid_begin InputIterator with elements indication valid(1)
|
||||
/// or null(0) values.
|
||||
/// \return Status
|
||||
template <typename ValuesIter, typename ValidIter>
|
||||
typename std::enable_if<!std::is_pointer<ValidIter>::value, Status>::type AppendValues(
|
||||
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
|
||||
static_assert(!internal::is_null_pointer<ValidIter>::value,
|
||||
"Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
|
||||
"version instead");
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(values_begin, values_end);
|
||||
null_bitmap_builder_.UnsafeAppend<true>(
|
||||
length, [&valid_begin]() -> bool { return *valid_begin++; });
|
||||
length_ = null_bitmap_builder_.length();
|
||||
null_count_ = null_bitmap_builder_.false_count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Same as above, with a pointer type ValidIter
|
||||
template <typename ValuesIter, typename ValidIter>
|
||||
typename std::enable_if<std::is_pointer<ValidIter>::value, Status>::type AppendValues(
|
||||
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(values_begin, values_end);
|
||||
// this updates the length_
|
||||
if (valid_begin == NULLPTR) {
|
||||
UnsafeSetNotNull(length);
|
||||
} else {
|
||||
null_bitmap_builder_.UnsafeAppend<true>(
|
||||
length, [&valid_begin]() -> bool { return *valid_begin++; });
|
||||
length_ = null_bitmap_builder_.length();
|
||||
null_count_ = null_bitmap_builder_.false_count();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Append a single scalar under the assumption that the underlying Buffer is
|
||||
/// large enough.
|
||||
///
|
||||
/// This method does not capacity-check; make sure to call Reserve
|
||||
/// beforehand.
|
||||
void UnsafeAppend(const value_type val) {
|
||||
ArrayBuilder::UnsafeAppendToBitmap(true);
|
||||
data_builder_.UnsafeAppend(val);
|
||||
}
|
||||
|
||||
void UnsafeAppendNull() {
|
||||
ArrayBuilder::UnsafeAppendToBitmap(false);
|
||||
data_builder_.UnsafeAppend(0);
|
||||
}
|
||||
|
||||
protected:
|
||||
TypedBufferBuilder<value_type> data_builder_;
|
||||
};
|
||||
|
||||
// Builders
|
||||
|
||||
using UInt8Builder = NumericBuilder<UInt8Type>;
|
||||
using UInt16Builder = NumericBuilder<UInt16Type>;
|
||||
using UInt32Builder = NumericBuilder<UInt32Type>;
|
||||
using UInt64Builder = NumericBuilder<UInt64Type>;
|
||||
|
||||
using Int8Builder = NumericBuilder<Int8Type>;
|
||||
using Int16Builder = NumericBuilder<Int16Type>;
|
||||
using Int32Builder = NumericBuilder<Int32Type>;
|
||||
using Int64Builder = NumericBuilder<Int64Type>;
|
||||
|
||||
using HalfFloatBuilder = NumericBuilder<HalfFloatType>;
|
||||
using FloatBuilder = NumericBuilder<FloatType>;
|
||||
using DoubleBuilder = NumericBuilder<DoubleType>;
|
||||
|
||||
class ARROW_EXPORT BooleanBuilder : public ArrayBuilder {
|
||||
public:
|
||||
using value_type = bool;
|
||||
explicit BooleanBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
|
||||
explicit BooleanBuilder(const std::shared_ptr<DataType>& type, MemoryPool* pool);
|
||||
|
||||
/// Write nulls as uint8_t* (0 value indicates null) into pre-allocated memory
|
||||
Status AppendNulls(int64_t length) final {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend(length, false);
|
||||
UnsafeSetNull(length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendNull() final {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppendNull();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// Scalar append
|
||||
Status Append(const bool val) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(val);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const uint8_t val) { return Append(val != 0); }
|
||||
|
||||
/// Scalar append, without checking for capacity
|
||||
void UnsafeAppend(const bool val) {
|
||||
data_builder_.UnsafeAppend(val);
|
||||
UnsafeAppendToBitmap(true);
|
||||
}
|
||||
|
||||
void UnsafeAppendNull() {
|
||||
data_builder_.UnsafeAppend(false);
|
||||
UnsafeAppendToBitmap(false);
|
||||
}
|
||||
|
||||
void UnsafeAppend(const uint8_t val) { UnsafeAppend(val != 0); }
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous array of bytes (non-zero is 1)
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] valid_bytes an optional sequence of bytes where non-zero
|
||||
/// indicates a valid (non-null) value
|
||||
/// \return Status
|
||||
Status AppendValues(const uint8_t* values, int64_t length,
|
||||
const uint8_t* valid_bytes = NULLPTR);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a contiguous C array of values
|
||||
/// \param[in] length the number of values to append
|
||||
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
|
||||
/// (0). Equal in length to values
|
||||
/// \return Status
|
||||
Status AppendValues(const uint8_t* values, int64_t length,
|
||||
const std::vector<bool>& is_valid);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a std::vector of bytes
|
||||
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
|
||||
/// (0). Equal in length to values
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<uint8_t>& values,
|
||||
const std::vector<bool>& is_valid);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values a std::vector of bytes
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<uint8_t>& values);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values an std::vector<bool> indicating true (1) or false
|
||||
/// \param[in] is_valid an std::vector<bool> indicating valid (1) or null
|
||||
/// (0). Equal in length to values
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<bool>& values, const std::vector<bool>& is_valid);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values an std::vector<bool> indicating true (1) or false
|
||||
/// \return Status
|
||||
Status AppendValues(const std::vector<bool>& values);
|
||||
|
||||
/// \brief Append a sequence of elements in one shot
|
||||
/// \param[in] values_begin InputIterator to the beginning of the values
|
||||
/// \param[in] values_end InputIterator pointing to the end of the values
|
||||
/// or null(0) values
|
||||
/// \return Status
|
||||
template <typename ValuesIter>
|
||||
Status AppendValues(ValuesIter values_begin, ValuesIter values_end) {
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend<false>(
|
||||
length, [&values_begin]() -> bool { return *values_begin++; });
|
||||
// this updates length_
|
||||
UnsafeSetNotNull(length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append a sequence of elements in one shot, with a specified nullmap
|
||||
/// \param[in] values_begin InputIterator to the beginning of the values
|
||||
/// \param[in] values_end InputIterator pointing to the end of the values
|
||||
/// \param[in] valid_begin InputIterator with elements indication valid(1)
|
||||
/// or null(0) values
|
||||
/// \return Status
|
||||
template <typename ValuesIter, typename ValidIter>
|
||||
typename std::enable_if<!std::is_pointer<ValidIter>::value, Status>::type AppendValues(
|
||||
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
|
||||
static_assert(!internal::is_null_pointer<ValidIter>::value,
|
||||
"Don't pass a NULLPTR directly as valid_begin, use the 2-argument "
|
||||
"version instead");
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
|
||||
data_builder_.UnsafeAppend<false>(
|
||||
length, [&values_begin]() -> bool { return *values_begin++; });
|
||||
null_bitmap_builder_.UnsafeAppend<true>(
|
||||
length, [&valid_begin]() -> bool { return *valid_begin++; });
|
||||
length_ = null_bitmap_builder_.length();
|
||||
null_count_ = null_bitmap_builder_.false_count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Same as above, for a pointer type ValidIter
|
||||
template <typename ValuesIter, typename ValidIter>
|
||||
typename std::enable_if<std::is_pointer<ValidIter>::value, Status>::type AppendValues(
|
||||
ValuesIter values_begin, ValuesIter values_end, ValidIter valid_begin) {
|
||||
int64_t length = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
data_builder_.UnsafeAppend<false>(
|
||||
length, [&values_begin]() -> bool { return *values_begin++; });
|
||||
|
||||
if (valid_begin == NULLPTR) {
|
||||
UnsafeSetNotNull(length);
|
||||
} else {
|
||||
null_bitmap_builder_.UnsafeAppend<true>(
|
||||
length, [&valid_begin]() -> bool { return *valid_begin++; });
|
||||
}
|
||||
length_ = null_bitmap_builder_.length();
|
||||
null_count_ = null_bitmap_builder_.false_count();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AppendValues(int64_t length, bool value);
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<BooleanArray>* out) { return FinishTyped(out); }
|
||||
|
||||
void Reset() override;
|
||||
Status Resize(int64_t capacity) override;
|
||||
|
||||
protected:
|
||||
TypedBufferBuilder<bool> data_builder_;
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,70 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Contains declarations of time related Arrow builder types.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/array/builder_binary.h"
|
||||
#include "arrow/array/builder_primitive.h"
|
||||
#include "arrow/buffer-builder.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/type_traits.h"
|
||||
#include "arrow/util/macros.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class ARROW_EXPORT DayTimeIntervalBuilder : public ArrayBuilder {
|
||||
public:
|
||||
using DayMilliseconds = DayTimeIntervalType::DayMilliseconds;
|
||||
|
||||
explicit DayTimeIntervalBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: DayTimeIntervalBuilder(day_time_interval(), pool) {}
|
||||
|
||||
DayTimeIntervalBuilder(std::shared_ptr<DataType> type,
|
||||
MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: ArrayBuilder(type, pool),
|
||||
builder_(fixed_size_binary(sizeof(DayMilliseconds)), pool) {}
|
||||
|
||||
void Reset() override { builder_.Reset(); }
|
||||
Status Resize(int64_t capacity) override { return builder_.Resize(capacity); }
|
||||
Status Append(DayMilliseconds day_millis) {
|
||||
return builder_.Append(reinterpret_cast<uint8_t*>(&day_millis));
|
||||
}
|
||||
void UnsafeAppend(DayMilliseconds day_millis) {
|
||||
builder_.UnsafeAppend(reinterpret_cast<uint8_t*>(&day_millis));
|
||||
}
|
||||
using ArrayBuilder::UnsafeAppendNull;
|
||||
Status AppendNull() override { return builder_.AppendNull(); }
|
||||
Status AppendNulls(int64_t length) override { return builder_.AppendNulls(length); }
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
|
||||
auto result = builder_.FinishInternal(out);
|
||||
if (*out != NULLPTR) {
|
||||
(*out)->type = type();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private:
|
||||
FixedSizeBinaryBuilder builder_;
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,106 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/array/builder_base.h"
|
||||
#include "arrow/buffer-builder.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
/// \class DenseUnionBuilder
|
||||
///
|
||||
/// You need to call AppendChild for each of the children builders you want
|
||||
/// to use. The function will return an int8_t, which is the type tag
|
||||
/// associated with that child. You can then call Append with that tag
|
||||
/// (followed by an append on the child builder) to add elements to
|
||||
/// the union array.
|
||||
///
|
||||
/// You can either specify the type when the UnionBuilder is constructed
|
||||
/// or let the UnionBuilder infer the type at runtime (by omitting the
|
||||
/// type argument from the constructor).
|
||||
///
|
||||
/// This API is EXPERIMENTAL.
|
||||
class ARROW_EXPORT DenseUnionBuilder : public ArrayBuilder {
|
||||
public:
|
||||
/// Use this constructor to incrementally build the union array along
|
||||
/// with types, offsets, and null bitmap.
|
||||
explicit DenseUnionBuilder(MemoryPool* pool,
|
||||
const std::shared_ptr<DataType>& type = NULLPTR);
|
||||
|
||||
Status AppendNull() final {
|
||||
ARROW_RETURN_NOT_OK(types_builder_.Append(0));
|
||||
ARROW_RETURN_NOT_OK(offsets_builder_.Append(0));
|
||||
return AppendToBitmap(false);
|
||||
}
|
||||
|
||||
Status AppendNulls(int64_t length) final {
|
||||
ARROW_RETURN_NOT_OK(types_builder_.Reserve(length));
|
||||
ARROW_RETURN_NOT_OK(offsets_builder_.Reserve(length));
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
for (int64_t i = 0; i < length; ++i) {
|
||||
types_builder_.UnsafeAppend(0);
|
||||
offsets_builder_.UnsafeAppend(0);
|
||||
}
|
||||
return AppendToBitmap(length, false);
|
||||
}
|
||||
|
||||
/// \brief Append an element to the UnionArray. This must be followed
|
||||
/// by an append to the appropriate child builder.
|
||||
/// \param[in] type index of the child the value will be appended
|
||||
/// \param[in] offset offset of the value in that child
|
||||
Status Append(int8_t type, int32_t offset) {
|
||||
ARROW_RETURN_NOT_OK(types_builder_.Append(type));
|
||||
ARROW_RETURN_NOT_OK(offsets_builder_.Append(offset));
|
||||
return AppendToBitmap(true);
|
||||
}
|
||||
|
||||
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
|
||||
|
||||
/// \cond FALSE
|
||||
using ArrayBuilder::Finish;
|
||||
/// \endcond
|
||||
|
||||
Status Finish(std::shared_ptr<UnionArray>* out) { return FinishTyped(out); }
|
||||
|
||||
/// \brief Make a new child builder available to the UnionArray
|
||||
///
|
||||
/// \param[in] child the child builder
|
||||
/// \param[in] field_name the name of the field in the union array type
|
||||
/// if type inference is used
|
||||
/// \return child index, which is the "type" argument that needs
|
||||
/// to be passed to the "Append" method to add a new element to
|
||||
/// the union array.
|
||||
int8_t AppendChild(const std::shared_ptr<ArrayBuilder>& child,
|
||||
const std::string& field_name = "") {
|
||||
children_.push_back(child);
|
||||
field_names_.push_back(field_name);
|
||||
return static_cast<int8_t>(children_.size() - 1);
|
||||
}
|
||||
|
||||
private:
|
||||
TypedBufferBuilder<int8_t> types_builder_;
|
||||
TypedBufferBuilder<int32_t> offsets_builder_;
|
||||
std::vector<std::string> field_names_;
|
||||
};
|
||||
|
||||
} // namespace arrow
|
@ -1,39 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/memory_pool.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
/// \brief Concatenate arrays
|
||||
///
|
||||
/// \param[in] arrays a vector of arrays to be concatenated
|
||||
/// \param[in] pool memory to store the result will be allocated from this memory pool
|
||||
/// \param[out] out the resulting concatenated array
|
||||
/// \return Status
|
||||
ARROW_EXPORT
|
||||
Status Concatenate(const ArrayVector& arrays, MemoryPool* pool,
|
||||
std::shared_ptr<Array>* out);
|
||||
|
||||
} // namespace arrow
|
@ -1,379 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_BUFFER_BUILDER_H
|
||||
#define ARROW_BUFFER_BUILDER_H
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "arrow/buffer.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/bit-util.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/ubsan.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Buffer builder classes
|
||||
|
||||
/// \class BufferBuilder
|
||||
/// \brief A class for incrementally building a contiguous chunk of in-memory
|
||||
/// data
|
||||
class ARROW_EXPORT BufferBuilder {
|
||||
public:
|
||||
explicit BufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: pool_(pool),
|
||||
data_(/*ensure never null to make ubsan happy and avoid check penalties below*/
|
||||
&util::internal::non_null_filler),
|
||||
|
||||
capacity_(0),
|
||||
size_(0) {}
|
||||
|
||||
/// \brief Resize the buffer to the nearest multiple of 64 bytes
|
||||
///
|
||||
/// \param new_capacity the new capacity of the of the builder. Will be
|
||||
/// rounded up to a multiple of 64 bytes for padding \param shrink_to_fit if
|
||||
/// new capacity is smaller than the existing size, reallocate internal
|
||||
/// buffer. Set to false to avoid reallocations when shrinking the builder.
|
||||
/// \return Status
|
||||
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
|
||||
// Resize(0) is a no-op
|
||||
if (new_capacity == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
if (buffer_ == NULLPTR) {
|
||||
ARROW_RETURN_NOT_OK(AllocateResizableBuffer(pool_, new_capacity, &buffer_));
|
||||
} else {
|
||||
ARROW_RETURN_NOT_OK(buffer_->Resize(new_capacity, shrink_to_fit));
|
||||
}
|
||||
capacity_ = buffer_->capacity();
|
||||
data_ = buffer_->mutable_data();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Ensure that builder can accommodate the additional number of bytes
|
||||
/// without the need to perform allocations
|
||||
///
|
||||
/// \param[in] additional_bytes number of additional bytes to make space for
|
||||
/// \return Status
|
||||
Status Reserve(const int64_t additional_bytes) {
|
||||
auto min_capacity = size_ + additional_bytes;
|
||||
if (min_capacity <= capacity_) {
|
||||
return Status::OK();
|
||||
}
|
||||
return Resize(GrowByFactor(capacity_, min_capacity), false);
|
||||
}
|
||||
|
||||
/// \brief Return a capacity expanded by an unspecified growth factor
|
||||
static int64_t GrowByFactor(int64_t current_capacity, int64_t new_capacity) {
|
||||
// NOTE: Doubling isn't a great overallocation practice
|
||||
// see https://github.com/facebook/folly/blob/master/folly/docs/FBVector.md
|
||||
// for discussion.
|
||||
// Grow exactly if a large upsize (the caller might know the exact final size).
|
||||
// Otherwise overallocate by 1.5 to keep a linear amortized cost.
|
||||
return std::max(new_capacity, current_capacity * 3 / 2);
|
||||
}
|
||||
|
||||
/// \brief Append the given data to the buffer
|
||||
///
|
||||
/// The buffer is automatically expanded if necessary.
|
||||
Status Append(const void* data, const int64_t length) {
|
||||
if (ARROW_PREDICT_FALSE(size_ + length > capacity_)) {
|
||||
ARROW_RETURN_NOT_OK(Resize(GrowByFactor(capacity_, size_ + length), false));
|
||||
}
|
||||
UnsafeAppend(data, length);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/// \brief Append copies of a value to the buffer
|
||||
///
|
||||
/// The buffer is automatically expanded if necessary.
|
||||
Status Append(const int64_t num_copies, uint8_t value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(num_copies));
|
||||
UnsafeAppend(num_copies, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Advance pointer and zero out memory
|
||||
Status Advance(const int64_t length) { return Append(length, 0); }
|
||||
|
||||
// Advance pointer, but don't allocate or zero memory
|
||||
void UnsafeAdvance(const int64_t length) { size_ += length; }
|
||||
|
||||
// Unsafe methods don't check existing size
|
||||
void UnsafeAppend(const void* data, const int64_t length) {
|
||||
memcpy(data_ + size_, data, static_cast<size_t>(length));
|
||||
size_ += length;
|
||||
}
|
||||
|
||||
void UnsafeAppend(const int64_t num_copies, uint8_t value) {
|
||||
memset(data_ + size_, value, static_cast<size_t>(num_copies));
|
||||
size_ += num_copies;
|
||||
}
|
||||
|
||||
/// \brief Return result of builder as a Buffer object.
|
||||
///
|
||||
/// The builder is reset and can be reused afterwards.
|
||||
///
|
||||
/// \param[out] out the finalized Buffer object
|
||||
/// \param shrink_to_fit if the buffer size is smaller than its capacity,
|
||||
/// reallocate to fit more tightly in memory. Set to false to avoid
|
||||
/// a reallocation, at the expense of potentially more memory consumption.
|
||||
/// \return Status
|
||||
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
|
||||
ARROW_RETURN_NOT_OK(Resize(size_, shrink_to_fit));
|
||||
if (size_ != 0) buffer_->ZeroPadding();
|
||||
*out = buffer_;
|
||||
if (*out == NULLPTR) {
|
||||
ARROW_RETURN_NOT_OK(AllocateBuffer(pool_, 0, out));
|
||||
}
|
||||
Reset();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
buffer_ = NULLPTR;
|
||||
capacity_ = size_ = 0;
|
||||
}
|
||||
|
||||
/// \brief Set size to a smaller value without modifying builder
|
||||
/// contents. For reusable BufferBuilder classes
|
||||
/// \param[in] position must be non-negative and less than or equal
|
||||
/// to the current length()
|
||||
void Rewind(int64_t position) { size_ = position; }
|
||||
|
||||
int64_t capacity() const { return capacity_; }
|
||||
int64_t length() const { return size_; }
|
||||
const uint8_t* data() const { return data_; }
|
||||
uint8_t* mutable_data() { return data_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<ResizableBuffer> buffer_;
|
||||
MemoryPool* pool_;
|
||||
uint8_t* data_;
|
||||
int64_t capacity_;
|
||||
int64_t size_;
|
||||
};
|
||||
|
||||
template <typename T, typename Enable = void>
|
||||
class TypedBufferBuilder;
|
||||
|
||||
/// \brief A BufferBuilder for building a buffer of arithmetic elements
|
||||
template <typename T>
|
||||
class TypedBufferBuilder<T, typename std::enable_if<std::is_arithmetic<T>::value>::type> {
|
||||
public:
|
||||
explicit TypedBufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: bytes_builder_(pool) {}
|
||||
|
||||
Status Append(T value) {
|
||||
return bytes_builder_.Append(reinterpret_cast<uint8_t*>(&value), sizeof(T));
|
||||
}
|
||||
|
||||
Status Append(const T* values, int64_t num_elements) {
|
||||
return bytes_builder_.Append(reinterpret_cast<const uint8_t*>(values),
|
||||
num_elements * sizeof(T));
|
||||
}
|
||||
|
||||
Status Append(const int64_t num_copies, T value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(num_copies + length()));
|
||||
UnsafeAppend(num_copies, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void UnsafeAppend(T value) {
|
||||
bytes_builder_.UnsafeAppend(reinterpret_cast<uint8_t*>(&value), sizeof(T));
|
||||
}
|
||||
|
||||
void UnsafeAppend(const T* values, int64_t num_elements) {
|
||||
bytes_builder_.UnsafeAppend(reinterpret_cast<const uint8_t*>(values),
|
||||
num_elements * sizeof(T));
|
||||
}
|
||||
|
||||
template <typename Iter>
|
||||
void UnsafeAppend(Iter values_begin, Iter values_end) {
|
||||
int64_t num_elements = static_cast<int64_t>(std::distance(values_begin, values_end));
|
||||
auto data = mutable_data() + length();
|
||||
bytes_builder_.UnsafeAdvance(num_elements * sizeof(T));
|
||||
std::copy(values_begin, values_end, data);
|
||||
}
|
||||
|
||||
void UnsafeAppend(const int64_t num_copies, T value) {
|
||||
auto data = mutable_data() + length();
|
||||
bytes_builder_.UnsafeAppend(num_copies * sizeof(T), 0);
|
||||
for (const auto end = data + num_copies; data != end; ++data) {
|
||||
*data = value;
|
||||
}
|
||||
}
|
||||
|
||||
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
|
||||
return bytes_builder_.Resize(new_capacity * sizeof(T), shrink_to_fit);
|
||||
}
|
||||
|
||||
Status Reserve(const int64_t additional_elements) {
|
||||
return bytes_builder_.Reserve(additional_elements * sizeof(T));
|
||||
}
|
||||
|
||||
Status Advance(const int64_t length) {
|
||||
return bytes_builder_.Advance(length * sizeof(T));
|
||||
}
|
||||
|
||||
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
|
||||
return bytes_builder_.Finish(out, shrink_to_fit);
|
||||
}
|
||||
|
||||
void Reset() { bytes_builder_.Reset(); }
|
||||
|
||||
int64_t length() const { return bytes_builder_.length() / sizeof(T); }
|
||||
int64_t capacity() const { return bytes_builder_.capacity() / sizeof(T); }
|
||||
const T* data() const { return reinterpret_cast<const T*>(bytes_builder_.data()); }
|
||||
T* mutable_data() { return reinterpret_cast<T*>(bytes_builder_.mutable_data()); }
|
||||
|
||||
private:
|
||||
BufferBuilder bytes_builder_;
|
||||
};
|
||||
|
||||
/// \brief A BufferBuilder for building a buffer containing a bitmap
|
||||
template <>
|
||||
class TypedBufferBuilder<bool> {
|
||||
public:
|
||||
explicit TypedBufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT)
|
||||
: bytes_builder_(pool) {}
|
||||
|
||||
Status Append(bool value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(1));
|
||||
UnsafeAppend(value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const uint8_t* valid_bytes, int64_t num_elements) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(num_elements));
|
||||
UnsafeAppend(valid_bytes, num_elements);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Append(const int64_t num_copies, bool value) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(num_copies));
|
||||
UnsafeAppend(num_copies, value);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void UnsafeAppend(bool value) {
|
||||
BitUtil::SetBitTo(mutable_data(), bit_length_, value);
|
||||
if (!value) {
|
||||
++false_count_;
|
||||
}
|
||||
++bit_length_;
|
||||
}
|
||||
|
||||
void UnsafeAppend(const uint8_t* bytes, int64_t num_elements) {
|
||||
if (num_elements == 0) return;
|
||||
int64_t i = 0;
|
||||
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
|
||||
bool value = bytes[i++];
|
||||
false_count_ += !value;
|
||||
return value;
|
||||
});
|
||||
bit_length_ += num_elements;
|
||||
}
|
||||
|
||||
void UnsafeAppend(const int64_t num_copies, bool value) {
|
||||
BitUtil::SetBitsTo(mutable_data(), bit_length_, num_copies, value);
|
||||
false_count_ += num_copies * !value;
|
||||
bit_length_ += num_copies;
|
||||
}
|
||||
|
||||
template <bool count_falses, typename Generator>
|
||||
void UnsafeAppend(const int64_t num_elements, Generator&& gen) {
|
||||
if (num_elements == 0) return;
|
||||
|
||||
if (count_falses) {
|
||||
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements, [&] {
|
||||
bool value = gen();
|
||||
false_count_ += !value;
|
||||
return value;
|
||||
});
|
||||
} else {
|
||||
internal::GenerateBitsUnrolled(mutable_data(), bit_length_, num_elements,
|
||||
std::forward<Generator>(gen));
|
||||
}
|
||||
bit_length_ += num_elements;
|
||||
}
|
||||
|
||||
Status Resize(const int64_t new_capacity, bool shrink_to_fit = true) {
|
||||
const int64_t old_byte_capacity = bytes_builder_.capacity();
|
||||
ARROW_RETURN_NOT_OK(
|
||||
bytes_builder_.Resize(BitUtil::BytesForBits(new_capacity), shrink_to_fit));
|
||||
// Resize() may have chosen a larger capacity (e.g. for padding),
|
||||
// so ask it again before calling memset().
|
||||
const int64_t new_byte_capacity = bytes_builder_.capacity();
|
||||
if (new_byte_capacity > old_byte_capacity) {
|
||||
// The additional buffer space is 0-initialized for convenience,
|
||||
// so that other methods can simply bump the length.
|
||||
memset(mutable_data() + old_byte_capacity, 0,
|
||||
static_cast<size_t>(new_byte_capacity - old_byte_capacity));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Reserve(const int64_t additional_elements) {
|
||||
return Resize(
|
||||
BufferBuilder::GrowByFactor(bit_length_, bit_length_ + additional_elements),
|
||||
false);
|
||||
}
|
||||
|
||||
Status Advance(const int64_t length) {
|
||||
ARROW_RETURN_NOT_OK(Reserve(length));
|
||||
bit_length_ += length;
|
||||
false_count_ += length;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Finish(std::shared_ptr<Buffer>* out, bool shrink_to_fit = true) {
|
||||
// set bytes_builder_.size_ == byte size of data
|
||||
bytes_builder_.UnsafeAdvance(BitUtil::BytesForBits(bit_length_) -
|
||||
bytes_builder_.length());
|
||||
bit_length_ = false_count_ = 0;
|
||||
return bytes_builder_.Finish(out, shrink_to_fit);
|
||||
}
|
||||
|
||||
void Reset() {
|
||||
bytes_builder_.Reset();
|
||||
bit_length_ = false_count_ = 0;
|
||||
}
|
||||
|
||||
int64_t length() const { return bit_length_; }
|
||||
int64_t capacity() const { return bytes_builder_.capacity() * 8; }
|
||||
const uint8_t* data() const { return bytes_builder_.data(); }
|
||||
uint8_t* mutable_data() { return bytes_builder_.mutable_data(); }
|
||||
int64_t false_count() const { return false_count_; }
|
||||
|
||||
private:
|
||||
BufferBuilder bytes_builder_;
|
||||
int64_t bit_length_ = 0;
|
||||
int64_t false_count_ = 0;
|
||||
};
|
||||
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_BUFFER_BUILDER_H
|
444
cpp/thirdparty/knowhere_build/include/arrow/buffer.h
vendored
444
cpp/thirdparty/knowhere_build/include/arrow/buffer.h
vendored
@ -1,444 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_BUFFER_H
|
||||
#define ARROW_BUFFER_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/memory_pool.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/string_view.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
// ----------------------------------------------------------------------
|
||||
// Buffer classes
|
||||
|
||||
/// \class Buffer
|
||||
/// \brief Object containing a pointer to a piece of contiguous memory with a
|
||||
/// particular size.
|
||||
///
|
||||
/// Buffers have two related notions of length: size and capacity. Size is
|
||||
/// the number of bytes that might have valid data. Capacity is the number
|
||||
/// of bytes that were allocated for the buffer in total.
|
||||
///
|
||||
/// The Buffer base class does not own its memory, but subclasses often do.
|
||||
///
|
||||
/// The following invariant is always true: Size <= Capacity
|
||||
class ARROW_EXPORT Buffer {
|
||||
public:
|
||||
/// \brief Construct from buffer and size without copying memory
|
||||
///
|
||||
/// \param[in] data a memory buffer
|
||||
/// \param[in] size buffer size
|
||||
///
|
||||
/// \note The passed memory must be kept alive through some other means
|
||||
Buffer(const uint8_t* data, int64_t size)
|
||||
: is_mutable_(false),
|
||||
data_(data),
|
||||
mutable_data_(NULLPTR),
|
||||
size_(size),
|
||||
capacity_(size) {}
|
||||
|
||||
/// \brief Construct from string_view without copying memory
|
||||
///
|
||||
/// \param[in] data a string_view object
|
||||
///
|
||||
/// \note The memory viewed by data must not be deallocated in the lifetime of the
|
||||
/// Buffer; temporary rvalue strings must be stored in an lvalue somewhere
|
||||
explicit Buffer(util::string_view data)
|
||||
: Buffer(reinterpret_cast<const uint8_t*>(data.data()),
|
||||
static_cast<int64_t>(data.size())) {}
|
||||
|
||||
virtual ~Buffer() = default;
|
||||
|
||||
/// An offset into data that is owned by another buffer, but we want to be
|
||||
/// able to retain a valid pointer to it even after other shared_ptr's to the
|
||||
/// parent buffer have been destroyed
|
||||
///
|
||||
/// This method makes no assertions about alignment or padding of the buffer but
|
||||
/// in general we expected buffers to be aligned and padded to 64 bytes. In the future
|
||||
/// we might add utility methods to help determine if a buffer satisfies this contract.
|
||||
Buffer(const std::shared_ptr<Buffer>& parent, const int64_t offset, const int64_t size)
|
||||
: Buffer(parent->data() + offset, size) {
|
||||
parent_ = parent;
|
||||
}
|
||||
|
||||
uint8_t operator[](std::size_t i) const { return data_[i]; }
|
||||
|
||||
bool is_mutable() const { return is_mutable_; }
|
||||
|
||||
/// \brief Construct a new std::string with a hexadecimal representation of the buffer.
|
||||
/// \return std::string
|
||||
std::string ToHexString();
|
||||
|
||||
/// Return true if both buffers are the same size and contain the same bytes
|
||||
/// up to the number of compared bytes
|
||||
bool Equals(const Buffer& other, int64_t nbytes) const;
|
||||
|
||||
/// Return true if both buffers are the same size and contain the same bytes
|
||||
bool Equals(const Buffer& other) const;
|
||||
|
||||
/// Copy a section of the buffer into a new Buffer.
|
||||
Status Copy(const int64_t start, const int64_t nbytes, MemoryPool* pool,
|
||||
std::shared_ptr<Buffer>* out) const;
|
||||
|
||||
/// Copy a section of the buffer using the default memory pool into a new Buffer.
|
||||
Status Copy(const int64_t start, const int64_t nbytes,
|
||||
std::shared_ptr<Buffer>* out) const;
|
||||
|
||||
/// Zero bytes in padding, i.e. bytes between size_ and capacity_.
|
||||
void ZeroPadding() {
|
||||
#ifndef NDEBUG
|
||||
CheckMutable();
|
||||
#endif
|
||||
// A zero-capacity buffer can have a null data pointer
|
||||
if (capacity_ != 0) {
|
||||
memset(mutable_data_ + size_, 0, static_cast<size_t>(capacity_ - size_));
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief Construct a new buffer that owns its memory from a std::string
|
||||
///
|
||||
/// \param[in] data a std::string object
|
||||
/// \param[in] pool a memory pool
|
||||
/// \param[out] out the created buffer
|
||||
///
|
||||
/// \return Status message
|
||||
static Status FromString(const std::string& data, MemoryPool* pool,
|
||||
std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Construct a new buffer that owns its memory from a std::string
|
||||
/// using the default memory pool
|
||||
static Status FromString(const std::string& data, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Construct an immutable buffer that takes ownership of the contents
|
||||
/// of an std::string
|
||||
/// \param[in] data an rvalue-reference of a string
|
||||
/// \return a new Buffer instance
|
||||
static std::shared_ptr<Buffer> FromString(std::string&& data);
|
||||
|
||||
/// \brief Create buffer referencing typed memory with some length without
|
||||
/// copying
|
||||
/// \param[in] data the typed memory as C array
|
||||
/// \param[in] length the number of values in the array
|
||||
/// \return a new shared_ptr<Buffer>
|
||||
template <typename T, typename SizeType = int64_t>
|
||||
static std::shared_ptr<Buffer> Wrap(const T* data, SizeType length) {
|
||||
return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data),
|
||||
static_cast<int64_t>(sizeof(T) * length));
|
||||
}
|
||||
|
||||
/// \brief Create buffer referencing std::vector with some length without
|
||||
/// copying
|
||||
/// \param[in] data the vector to be referenced. If this vector is changed,
|
||||
/// the buffer may become invalid
|
||||
/// \return a new shared_ptr<Buffer>
|
||||
template <typename T>
|
||||
static std::shared_ptr<Buffer> Wrap(const std::vector<T>& data) {
|
||||
return std::make_shared<Buffer>(reinterpret_cast<const uint8_t*>(data.data()),
|
||||
static_cast<int64_t>(sizeof(T) * data.size()));
|
||||
}
|
||||
|
||||
/// \brief Copy buffer contents into a new std::string
|
||||
/// \return std::string
|
||||
/// \note Can throw std::bad_alloc if buffer is large
|
||||
std::string ToString() const;
|
||||
|
||||
/// \brief View buffer contents as a util::string_view
|
||||
/// \return util::string_view
|
||||
explicit operator util::string_view() const {
|
||||
return util::string_view(reinterpret_cast<const char*>(data_), size_);
|
||||
}
|
||||
|
||||
/// \brief Return a pointer to the buffer's data
|
||||
const uint8_t* data() const { return data_; }
|
||||
/// \brief Return a writable pointer to the buffer's data
|
||||
///
|
||||
/// The buffer has to be mutable. Otherwise, an assertion may be thrown
|
||||
/// or a null pointer may be returned.
|
||||
uint8_t* mutable_data() {
|
||||
#ifndef NDEBUG
|
||||
CheckMutable();
|
||||
#endif
|
||||
return mutable_data_;
|
||||
}
|
||||
|
||||
/// \brief Return the buffer's size in bytes
|
||||
int64_t size() const { return size_; }
|
||||
|
||||
/// \brief Return the buffer's capacity (number of allocated bytes)
|
||||
int64_t capacity() const { return capacity_; }
|
||||
|
||||
std::shared_ptr<Buffer> parent() const { return parent_; }
|
||||
|
||||
protected:
|
||||
bool is_mutable_;
|
||||
const uint8_t* data_;
|
||||
uint8_t* mutable_data_;
|
||||
int64_t size_;
|
||||
int64_t capacity_;
|
||||
|
||||
// null by default, but may be set
|
||||
std::shared_ptr<Buffer> parent_;
|
||||
|
||||
void CheckMutable() const;
|
||||
|
||||
private:
|
||||
ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer);
|
||||
};
|
||||
|
||||
using BufferVector = std::vector<std::shared_ptr<Buffer>>;
|
||||
|
||||
/// \defgroup buffer-slicing-functions Functions for slicing buffers
|
||||
///
|
||||
/// @{
|
||||
|
||||
/// \brief Construct a view on a buffer at the given offset and length.
|
||||
///
|
||||
/// This function cannot fail and does not check for errors (except in debug builds)
|
||||
static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
|
||||
const int64_t offset,
|
||||
const int64_t length) {
|
||||
return std::make_shared<Buffer>(buffer, offset, length);
|
||||
}
|
||||
|
||||
/// \brief Construct a view on a buffer at the given offset, up to the buffer's end.
|
||||
///
|
||||
/// This function cannot fail and does not check for errors (except in debug builds)
|
||||
static inline std::shared_ptr<Buffer> SliceBuffer(const std::shared_ptr<Buffer>& buffer,
|
||||
const int64_t offset) {
|
||||
int64_t length = buffer->size() - offset;
|
||||
return SliceBuffer(buffer, offset, length);
|
||||
}
|
||||
|
||||
/// \brief Like SliceBuffer, but construct a mutable buffer slice.
|
||||
///
|
||||
/// If the parent buffer is not mutable, behavior is undefined (it may abort
|
||||
/// in debug builds).
|
||||
ARROW_EXPORT
|
||||
std::shared_ptr<Buffer> SliceMutableBuffer(const std::shared_ptr<Buffer>& buffer,
|
||||
const int64_t offset, const int64_t length);
|
||||
|
||||
/// \brief Like SliceBuffer, but construct a mutable buffer slice.
|
||||
///
|
||||
/// If the parent buffer is not mutable, behavior is undefined (it may abort
|
||||
/// in debug builds).
|
||||
static inline std::shared_ptr<Buffer> SliceMutableBuffer(
|
||||
const std::shared_ptr<Buffer>& buffer, const int64_t offset) {
|
||||
int64_t length = buffer->size() - offset;
|
||||
return SliceMutableBuffer(buffer, offset, length);
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
/// \class MutableBuffer
|
||||
/// \brief A Buffer whose contents can be mutated. May or may not own its data.
|
||||
class ARROW_EXPORT MutableBuffer : public Buffer {
|
||||
public:
|
||||
MutableBuffer(uint8_t* data, const int64_t size) : Buffer(data, size) {
|
||||
mutable_data_ = data;
|
||||
is_mutable_ = true;
|
||||
}
|
||||
|
||||
MutableBuffer(const std::shared_ptr<Buffer>& parent, const int64_t offset,
|
||||
const int64_t size);
|
||||
|
||||
/// \brief Create buffer referencing typed memory with some length
|
||||
/// \param[in] data the typed memory as C array
|
||||
/// \param[in] length the number of values in the array
|
||||
/// \return a new shared_ptr<Buffer>
|
||||
template <typename T, typename SizeType = int64_t>
|
||||
static std::shared_ptr<Buffer> Wrap(T* data, SizeType length) {
|
||||
return std::make_shared<MutableBuffer>(reinterpret_cast<uint8_t*>(data),
|
||||
static_cast<int64_t>(sizeof(T) * length));
|
||||
}
|
||||
|
||||
protected:
|
||||
MutableBuffer() : Buffer(NULLPTR, 0) {}
|
||||
};
|
||||
|
||||
/// \class ResizableBuffer
|
||||
/// \brief A mutable buffer that can be resized
|
||||
class ARROW_EXPORT ResizableBuffer : public MutableBuffer {
|
||||
public:
|
||||
/// Change buffer reported size to indicated size, allocating memory if
|
||||
/// necessary. This will ensure that the capacity of the buffer is a multiple
|
||||
/// of 64 bytes as defined in Layout.md.
|
||||
/// Consider using ZeroPadding afterwards, to conform to the Arrow layout
|
||||
/// specification.
|
||||
///
|
||||
/// @param new_size The new size for the buffer.
|
||||
/// @param shrink_to_fit Whether to shrink the capacity if new size < current size
|
||||
virtual Status Resize(const int64_t new_size, bool shrink_to_fit = true) = 0;
|
||||
|
||||
/// Ensure that buffer has enough memory allocated to fit the indicated
|
||||
/// capacity (and meets the 64 byte padding requirement in Layout.md).
|
||||
/// It does not change buffer's reported size and doesn't zero the padding.
|
||||
virtual Status Reserve(const int64_t new_capacity) = 0;
|
||||
|
||||
template <class T>
|
||||
Status TypedResize(const int64_t new_nb_elements, bool shrink_to_fit = true) {
|
||||
return Resize(sizeof(T) * new_nb_elements, shrink_to_fit);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Status TypedReserve(const int64_t new_nb_elements) {
|
||||
return Reserve(sizeof(T) * new_nb_elements);
|
||||
}
|
||||
|
||||
protected:
|
||||
ResizableBuffer(uint8_t* data, int64_t size) : MutableBuffer(data, size) {}
|
||||
};
|
||||
|
||||
/// \defgroup buffer-allocation-functions Functions for allocating buffers
|
||||
///
|
||||
/// @{
|
||||
|
||||
/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding.
|
||||
///
|
||||
/// \param[in] pool a memory pool
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer (contains padding)
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateBuffer(MemoryPool* pool, const int64_t size, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a fixed size mutable buffer from a memory pool, zero its padding.
|
||||
///
|
||||
/// \param[in] pool a memory pool
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer (contains padding)
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateBuffer(MemoryPool* pool, const int64_t size, std::unique_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a fixed-size mutable buffer from the default memory pool
|
||||
///
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer (contains padding)
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateBuffer(const int64_t size, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a fixed-size mutable buffer from the default memory pool
|
||||
///
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer (contains padding)
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateBuffer(const int64_t size, std::unique_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a resizeable buffer from a memory pool, zero its padding.
|
||||
///
|
||||
/// \param[in] pool a memory pool
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateResizableBuffer(MemoryPool* pool, const int64_t size,
|
||||
std::shared_ptr<ResizableBuffer>* out);
|
||||
|
||||
/// \brief Allocate a resizeable buffer from a memory pool, zero its padding.
|
||||
///
|
||||
/// \param[in] pool a memory pool
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateResizableBuffer(MemoryPool* pool, const int64_t size,
|
||||
std::unique_ptr<ResizableBuffer>* out);
|
||||
|
||||
/// \brief Allocate a resizeable buffer from the default memory pool
|
||||
///
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateResizableBuffer(const int64_t size, std::shared_ptr<ResizableBuffer>* out);
|
||||
|
||||
/// \brief Allocate a resizeable buffer from the default memory pool
|
||||
///
|
||||
/// \param[in] size size of buffer to allocate
|
||||
/// \param[out] out the allocated buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateResizableBuffer(const int64_t size, std::unique_ptr<ResizableBuffer>* out);
|
||||
|
||||
/// \brief Allocate a bitmap buffer from a memory pool
|
||||
/// no guarantee on values is provided.
|
||||
///
|
||||
/// \param[in] pool memory pool to allocate memory from
|
||||
/// \param[in] length size in bits of bitmap to allocate
|
||||
/// \param[out] out the resulting buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateBitmap(MemoryPool* pool, int64_t length, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a zero-initialized bitmap buffer from a memory pool
|
||||
///
|
||||
/// \param[in] pool memory pool to allocate memory from
|
||||
/// \param[in] length size in bits of bitmap to allocate
|
||||
/// \param[out] out the resulting buffer (zero-initialized).
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateEmptyBitmap(MemoryPool* pool, int64_t length,
|
||||
std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Allocate a zero-initialized bitmap buffer from the default memory pool
|
||||
///
|
||||
/// \param[in] length size in bits of bitmap to allocate
|
||||
/// \param[out] out the resulting buffer
|
||||
///
|
||||
/// \return Status message
|
||||
ARROW_EXPORT
|
||||
Status AllocateEmptyBitmap(int64_t length, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Concatenate multiple buffers into a single buffer
|
||||
///
|
||||
/// \param[in] buffers to be concatenated
|
||||
/// \param[in] pool memory pool to allocate the new buffer from
|
||||
/// \param[out] out the concatenated buffer
|
||||
///
|
||||
/// \return Status
|
||||
ARROW_EXPORT
|
||||
Status ConcatenateBuffers(const BufferVector& buffers, MemoryPool* pool,
|
||||
std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// @}
|
||||
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_BUFFER_H
|
@ -1,58 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array/builder_adaptive.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_base.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_binary.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_decimal.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_dict.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_nested.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_primitive.h" // IWYU pragma: export
|
||||
#include "arrow/array/builder_time.h" // IWYU pragma: export
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class DataType;
|
||||
class MemoryPool;
|
||||
|
||||
/// \brief Construct an empty ArrayBuilder corresponding to the data
|
||||
/// type
|
||||
/// \param[in] pool the MemoryPool to use for allocations
|
||||
/// \param[in] type an instance of DictionaryType
|
||||
/// \param[out] out the created ArrayBuilder
|
||||
ARROW_EXPORT
|
||||
Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
|
||||
std::unique_ptr<ArrayBuilder>* out);
|
||||
|
||||
/// \brief Construct an empty DictionaryBuilder initialized optionally
|
||||
/// with a pre-existing dictionary
|
||||
/// \param[in] pool the MemoryPool to use for allocations
|
||||
/// \param[in] type an instance of DictionaryType
|
||||
/// \param[in] dictionary the initial dictionary, if any. May be nullptr
|
||||
/// \param[out] out the created ArrayBuilder
|
||||
ARROW_EXPORT
|
||||
Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
|
||||
const std::shared_ptr<Array>& dictionary,
|
||||
std::unique_ptr<ArrayBuilder>* out);
|
||||
|
||||
} // namespace arrow
|
@ -1,101 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Functions for comparing Arrow data structures
|
||||
|
||||
#ifndef ARROW_COMPARE_H
|
||||
#define ARROW_COMPARE_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
class Tensor;
|
||||
class SparseTensor;
|
||||
struct Scalar;
|
||||
|
||||
static constexpr double kDefaultAbsoluteTolerance = 1E-5;
|
||||
|
||||
/// A container of options for equality comparisons
|
||||
class EqualOptions {
|
||||
public:
|
||||
/// Whether or not NaNs are considered equal.
|
||||
bool nans_equal() const { return nans_equal_; }
|
||||
|
||||
/// Return a new EqualOptions object with the "nans_equal" property changed.
|
||||
EqualOptions nans_equal(bool v) const {
|
||||
auto res = EqualOptions(*this);
|
||||
res.nans_equal_ = v;
|
||||
return res;
|
||||
}
|
||||
|
||||
/// The absolute tolerance for approximate comparisons of floating-point values.
|
||||
double atol() const { return atol_; }
|
||||
|
||||
/// Return a new EqualOptions object with the "atol" property changed.
|
||||
EqualOptions atol(double v) const {
|
||||
auto res = EqualOptions(*this);
|
||||
res.atol_ = v;
|
||||
return res;
|
||||
}
|
||||
|
||||
static EqualOptions Defaults() { return EqualOptions(); }
|
||||
|
||||
protected:
|
||||
double atol_ = kDefaultAbsoluteTolerance;
|
||||
bool nans_equal_ = false;
|
||||
};
|
||||
|
||||
/// Returns true if the arrays are exactly equal
|
||||
bool ARROW_EXPORT ArrayEquals(const Array& left, const Array& right,
|
||||
const EqualOptions& = EqualOptions::Defaults());
|
||||
|
||||
bool ARROW_EXPORT TensorEquals(const Tensor& left, const Tensor& right);
|
||||
|
||||
/// EXPERIMENTAL: Returns true if the given sparse tensors are exactly equal
|
||||
bool ARROW_EXPORT SparseTensorEquals(const SparseTensor& left, const SparseTensor& right);
|
||||
|
||||
/// Returns true if the arrays are approximately equal. For non-floating point
|
||||
/// types, this is equivalent to ArrayEquals(left, right)
|
||||
bool ARROW_EXPORT ArrayApproxEquals(const Array& left, const Array& right,
|
||||
const EqualOptions& = EqualOptions::Defaults());
|
||||
|
||||
/// Returns true if indicated equal-length segment of arrays is exactly equal
|
||||
bool ARROW_EXPORT ArrayRangeEquals(const Array& left, const Array& right,
|
||||
int64_t start_idx, int64_t end_idx,
|
||||
int64_t other_start_idx);
|
||||
|
||||
/// Returns true if the type metadata are exactly equal
|
||||
/// \param[in] left a DataType
|
||||
/// \param[in] right a DataType
|
||||
/// \param[in] check_metadata whether to compare KeyValueMetadata for child
|
||||
/// fields
|
||||
bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
|
||||
bool check_metadata = true);
|
||||
|
||||
/// Returns true if scalars are equal
|
||||
/// \param[in] left a Scalar
|
||||
/// \param[in] right a Scalar
|
||||
bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right);
|
||||
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPARE_H
|
@ -1,33 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_API_H
|
||||
#define ARROW_COMPUTE_API_H
|
||||
|
||||
#include "arrow/compute/context.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernel.h" // IWYU pragma: export
|
||||
|
||||
#include "arrow/compute/kernels/boolean.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/cast.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/compare.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/count.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/hash.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/mean.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/sum.h" // IWYU pragma: export
|
||||
#include "arrow/compute/kernels/take.h" // IWYU pragma: export
|
||||
|
||||
#endif // ARROW_COMPUTE_API_H
|
@ -1,97 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/testing/gtest_util.h"
|
||||
#include "arrow/util/cpu-info.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
using internal::CpuInfo;
|
||||
static CpuInfo* cpu_info = CpuInfo::GetInstance();
|
||||
|
||||
static const int64_t kL1Size = cpu_info->CacheSize(CpuInfo::L1_CACHE);
|
||||
static const int64_t kL2Size = cpu_info->CacheSize(CpuInfo::L2_CACHE);
|
||||
static const int64_t kL3Size = cpu_info->CacheSize(CpuInfo::L3_CACHE);
|
||||
static const int64_t kCantFitInL3Size = kL3Size * 4;
|
||||
static const std::vector<int64_t> kMemorySizes = {kL1Size, kL2Size, kL3Size,
|
||||
kCantFitInL3Size};
|
||||
|
||||
template <typename Func>
|
||||
struct BenchmarkArgsType;
|
||||
|
||||
// Pattern matching that extracts the vector element type of Benchmark::Args()
|
||||
template <typename Values>
|
||||
struct BenchmarkArgsType<benchmark::internal::Benchmark* (
|
||||
benchmark::internal::Benchmark::*)(const std::vector<Values>&)> {
|
||||
using type = Values;
|
||||
};
|
||||
|
||||
// Benchmark changed its parameter type between releases from
|
||||
// int to int64_t. As it doesn't have version macros, we need
|
||||
// to apply C++ template magic.
|
||||
using ArgsType =
|
||||
typename BenchmarkArgsType<decltype(&benchmark::internal::Benchmark::Args)>::type;
|
||||
|
||||
void BenchmarkSetArgsWithSizes(benchmark::internal::Benchmark* bench,
|
||||
const std::vector<int64_t>& sizes = kMemorySizes) {
|
||||
bench->Unit(benchmark::kMicrosecond);
|
||||
|
||||
for (auto size : sizes)
|
||||
for (auto nulls : std::vector<ArgsType>({0, 1, 10, 50}))
|
||||
bench->Args({static_cast<ArgsType>(size), nulls});
|
||||
}
|
||||
|
||||
void BenchmarkSetArgs(benchmark::internal::Benchmark* bench) {
|
||||
BenchmarkSetArgsWithSizes(bench, kMemorySizes);
|
||||
}
|
||||
|
||||
void RegressionSetArgs(benchmark::internal::Benchmark* bench) {
|
||||
// Regression do not need to account for cache hierarchy, thus optimize for
|
||||
// the best case.
|
||||
BenchmarkSetArgsWithSizes(bench, {kL1Size});
|
||||
}
|
||||
|
||||
// RAII struct to handle some of the boilerplate in regression benchmarks
|
||||
struct RegressionArgs {
|
||||
// size of memory tested (per iteration) in bytes
|
||||
const int64_t size;
|
||||
|
||||
// proportion of nulls in generated arrays
|
||||
const double null_proportion;
|
||||
|
||||
explicit RegressionArgs(benchmark::State& state)
|
||||
: size(state.range(0)),
|
||||
null_proportion(static_cast<double>(state.range(1)) / 100.0),
|
||||
state_(state) {}
|
||||
|
||||
~RegressionArgs() {
|
||||
state_.counters["size"] = static_cast<double>(size);
|
||||
state_.counters["null_percent"] = static_cast<double>(state_.range(1));
|
||||
state_.SetBytesProcessed(state_.iterations() * size);
|
||||
}
|
||||
|
||||
private:
|
||||
benchmark::State& state_;
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,82 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_CONTEXT_H
|
||||
#define ARROW_COMPUTE_CONTEXT_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/memory_pool.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Buffer;
|
||||
|
||||
namespace internal {
|
||||
class CpuInfo;
|
||||
} // namespace internal
|
||||
|
||||
namespace compute {
|
||||
|
||||
#define RETURN_IF_ERROR(ctx) \
|
||||
if (ARROW_PREDICT_FALSE(ctx->HasError())) { \
|
||||
Status s = ctx->status(); \
|
||||
ctx->ResetStatus(); \
|
||||
return s; \
|
||||
}
|
||||
|
||||
/// \brief Container for variables and options used by function evaluation
|
||||
class ARROW_EXPORT FunctionContext {
|
||||
public:
|
||||
explicit FunctionContext(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT);
|
||||
MemoryPool* memory_pool() const;
|
||||
|
||||
/// \brief Allocate buffer from the context's memory pool
|
||||
Status Allocate(const int64_t nbytes, std::shared_ptr<Buffer>* out);
|
||||
|
||||
/// \brief Indicate that an error has occurred, to be checked by a parent caller
|
||||
/// \param[in] status a Status instance
|
||||
///
|
||||
/// \note Will not overwrite a prior set Status, so we will have the first
|
||||
/// error that occurred until FunctionContext::ResetStatus is called
|
||||
void SetStatus(const Status& status);
|
||||
|
||||
/// \brief Clear any error status
|
||||
void ResetStatus();
|
||||
|
||||
/// \brief Return true if an error has occurred
|
||||
bool HasError() const { return !status_.ok(); }
|
||||
|
||||
/// \brief Return the current status of the context
|
||||
const Status& status() const { return status_; }
|
||||
|
||||
internal::CpuInfo* cpu_info() const { return cpu_info_; }
|
||||
|
||||
private:
|
||||
Status status_;
|
||||
MemoryPool* pool_;
|
||||
internal::CpuInfo* cpu_info_;
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPUTE_CONTEXT_H
|
@ -1,261 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "arrow/compute/type_fwd.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
class LogicalType;
|
||||
class ExprVisitor;
|
||||
class Operation;
|
||||
|
||||
/// \brief Base class for all analytic expressions. Expressions may represent
|
||||
/// data values (scalars, arrays, tables)
|
||||
class ARROW_EXPORT Expr {
|
||||
public:
|
||||
/// \brief Instantiate expression from an abstract operation
|
||||
/// \param[in] op the operation that generates the expression
|
||||
explicit Expr(ConstOpPtr op);
|
||||
|
||||
virtual ~Expr() = default;
|
||||
|
||||
/// \brief A unique string identifier for the kind of expression
|
||||
virtual std::string kind() const = 0;
|
||||
|
||||
/// \brief Accept expression visitor
|
||||
/// TODO(wesm)
|
||||
// virtual Status Accept(ExprVisitor* visitor) const = 0;
|
||||
|
||||
/// \brief The underlying operation
|
||||
ConstOpPtr op() const { return op_; }
|
||||
|
||||
protected:
|
||||
ConstOpPtr op_;
|
||||
};
|
||||
|
||||
/// The value cardinality: one or many. These correspond to the arrow::Scalar
|
||||
/// and arrow::Array types
|
||||
enum class ValueRank { SCALAR, ARRAY };
|
||||
|
||||
/// \brief Base class for a data-generated expression with a fixed and known
|
||||
/// type. This includes arrays and scalars
|
||||
class ARROW_EXPORT ValueExpr : public Expr {
|
||||
public:
|
||||
/// \brief The name of the expression, if any. The default is unnamed
|
||||
// virtual const ExprName& name() const;
|
||||
LogicalTypePtr type() const;
|
||||
|
||||
/// \brief The value cardinality (scalar or array) of the expression
|
||||
virtual ValueRank rank() const = 0;
|
||||
|
||||
protected:
|
||||
ValueExpr(ConstOpPtr op, LogicalTypePtr type);
|
||||
|
||||
/// \brief The semantic data type of the expression
|
||||
LogicalTypePtr type_;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT ArrayExpr : public ValueExpr {
|
||||
protected:
|
||||
using ValueExpr::ValueExpr;
|
||||
std::string kind() const override;
|
||||
ValueRank rank() const override;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT ScalarExpr : public ValueExpr {
|
||||
protected:
|
||||
using ValueExpr::ValueExpr;
|
||||
std::string kind() const override;
|
||||
ValueRank rank() const override;
|
||||
};
|
||||
|
||||
namespace value {
|
||||
|
||||
// These are mixin classes to provide a type hierarchy for values identify
|
||||
class ValueMixin {};
|
||||
class Null : public ValueMixin {};
|
||||
class Bool : public ValueMixin {};
|
||||
class Number : public ValueMixin {};
|
||||
class Integer : public Number {};
|
||||
class SignedInteger : public Integer {};
|
||||
class Int8 : public SignedInteger {};
|
||||
class Int16 : public SignedInteger {};
|
||||
class Int32 : public SignedInteger {};
|
||||
class Int64 : public SignedInteger {};
|
||||
class UnsignedInteger : public Integer {};
|
||||
class UInt8 : public UnsignedInteger {};
|
||||
class UInt16 : public UnsignedInteger {};
|
||||
class UInt32 : public UnsignedInteger {};
|
||||
class UInt64 : public UnsignedInteger {};
|
||||
class Floating : public Number {};
|
||||
class Float16 : public Floating {};
|
||||
class Float32 : public Floating {};
|
||||
class Float64 : public Floating {};
|
||||
class Binary : public ValueMixin {};
|
||||
class Utf8 : public Binary {};
|
||||
class List : public ValueMixin {};
|
||||
class Struct : public ValueMixin {};
|
||||
|
||||
} // namespace value
|
||||
|
||||
#define SIMPLE_EXPR_FACTORY(NAME) ARROW_EXPORT ExprPtr NAME(ConstOpPtr op);
|
||||
|
||||
namespace scalar {
|
||||
|
||||
#define DECLARE_SCALAR_EXPR(TYPE) \
|
||||
class ARROW_EXPORT TYPE : public ScalarExpr, public value::TYPE { \
|
||||
public: \
|
||||
explicit TYPE(ConstOpPtr op); \
|
||||
using ScalarExpr::kind; \
|
||||
};
|
||||
|
||||
DECLARE_SCALAR_EXPR(Null)
|
||||
DECLARE_SCALAR_EXPR(Bool)
|
||||
DECLARE_SCALAR_EXPR(Int8)
|
||||
DECLARE_SCALAR_EXPR(Int16)
|
||||
DECLARE_SCALAR_EXPR(Int32)
|
||||
DECLARE_SCALAR_EXPR(Int64)
|
||||
DECLARE_SCALAR_EXPR(UInt8)
|
||||
DECLARE_SCALAR_EXPR(UInt16)
|
||||
DECLARE_SCALAR_EXPR(UInt32)
|
||||
DECLARE_SCALAR_EXPR(UInt64)
|
||||
DECLARE_SCALAR_EXPR(Float16)
|
||||
DECLARE_SCALAR_EXPR(Float32)
|
||||
DECLARE_SCALAR_EXPR(Float64)
|
||||
DECLARE_SCALAR_EXPR(Binary)
|
||||
DECLARE_SCALAR_EXPR(Utf8)
|
||||
|
||||
#undef DECLARE_SCALAR_EXPR
|
||||
|
||||
SIMPLE_EXPR_FACTORY(null);
|
||||
SIMPLE_EXPR_FACTORY(boolean);
|
||||
SIMPLE_EXPR_FACTORY(int8);
|
||||
SIMPLE_EXPR_FACTORY(int16);
|
||||
SIMPLE_EXPR_FACTORY(int32);
|
||||
SIMPLE_EXPR_FACTORY(int64);
|
||||
SIMPLE_EXPR_FACTORY(uint8);
|
||||
SIMPLE_EXPR_FACTORY(uint16);
|
||||
SIMPLE_EXPR_FACTORY(uint32);
|
||||
SIMPLE_EXPR_FACTORY(uint64);
|
||||
SIMPLE_EXPR_FACTORY(float16);
|
||||
SIMPLE_EXPR_FACTORY(float32);
|
||||
SIMPLE_EXPR_FACTORY(float64);
|
||||
SIMPLE_EXPR_FACTORY(binary);
|
||||
SIMPLE_EXPR_FACTORY(utf8);
|
||||
|
||||
class ARROW_EXPORT List : public ScalarExpr, public value::List {
|
||||
public:
|
||||
List(ConstOpPtr op, LogicalTypePtr type);
|
||||
using ScalarExpr::kind;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT Struct : public ScalarExpr, public value::Struct {
|
||||
public:
|
||||
Struct(ConstOpPtr op, LogicalTypePtr type);
|
||||
using ScalarExpr::kind;
|
||||
};
|
||||
|
||||
} // namespace scalar
|
||||
|
||||
namespace array {
|
||||
|
||||
#define DECLARE_ARRAY_EXPR(TYPE) \
|
||||
class ARROW_EXPORT TYPE : public ArrayExpr, public value::TYPE { \
|
||||
public: \
|
||||
explicit TYPE(ConstOpPtr op); \
|
||||
using ArrayExpr::kind; \
|
||||
};
|
||||
|
||||
DECLARE_ARRAY_EXPR(Null)
|
||||
DECLARE_ARRAY_EXPR(Bool)
|
||||
DECLARE_ARRAY_EXPR(Int8)
|
||||
DECLARE_ARRAY_EXPR(Int16)
|
||||
DECLARE_ARRAY_EXPR(Int32)
|
||||
DECLARE_ARRAY_EXPR(Int64)
|
||||
DECLARE_ARRAY_EXPR(UInt8)
|
||||
DECLARE_ARRAY_EXPR(UInt16)
|
||||
DECLARE_ARRAY_EXPR(UInt32)
|
||||
DECLARE_ARRAY_EXPR(UInt64)
|
||||
DECLARE_ARRAY_EXPR(Float16)
|
||||
DECLARE_ARRAY_EXPR(Float32)
|
||||
DECLARE_ARRAY_EXPR(Float64)
|
||||
DECLARE_ARRAY_EXPR(Binary)
|
||||
DECLARE_ARRAY_EXPR(Utf8)
|
||||
|
||||
#undef DECLARE_ARRAY_EXPR
|
||||
|
||||
SIMPLE_EXPR_FACTORY(null);
|
||||
SIMPLE_EXPR_FACTORY(boolean);
|
||||
SIMPLE_EXPR_FACTORY(int8);
|
||||
SIMPLE_EXPR_FACTORY(int16);
|
||||
SIMPLE_EXPR_FACTORY(int32);
|
||||
SIMPLE_EXPR_FACTORY(int64);
|
||||
SIMPLE_EXPR_FACTORY(uint8);
|
||||
SIMPLE_EXPR_FACTORY(uint16);
|
||||
SIMPLE_EXPR_FACTORY(uint32);
|
||||
SIMPLE_EXPR_FACTORY(uint64);
|
||||
SIMPLE_EXPR_FACTORY(float16);
|
||||
SIMPLE_EXPR_FACTORY(float32);
|
||||
SIMPLE_EXPR_FACTORY(float64);
|
||||
SIMPLE_EXPR_FACTORY(binary);
|
||||
SIMPLE_EXPR_FACTORY(utf8);
|
||||
|
||||
class ARROW_EXPORT List : public ArrayExpr, public value::List {
|
||||
public:
|
||||
List(ConstOpPtr op, LogicalTypePtr type);
|
||||
using ArrayExpr::kind;
|
||||
};
|
||||
|
||||
class ARROW_EXPORT Struct : public ArrayExpr, public value::Struct {
|
||||
public:
|
||||
Struct(ConstOpPtr op, LogicalTypePtr type);
|
||||
using ArrayExpr::kind;
|
||||
};
|
||||
|
||||
} // namespace array
|
||||
|
||||
#undef SIMPLE_EXPR_FACTORY
|
||||
|
||||
template <typename T, typename ObjectType>
|
||||
inline bool InheritsFrom(const ObjectType* obj) {
|
||||
return dynamic_cast<const T*>(obj) != NULLPTR;
|
||||
}
|
||||
|
||||
template <typename T, typename ObjectType>
|
||||
inline bool InheritsFrom(const ObjectType& obj) {
|
||||
return dynamic_cast<const T*>(&obj) != NULLPTR;
|
||||
}
|
||||
|
||||
/// \brief Construct a ScalarExpr containing an Operation given a logical type
|
||||
ARROW_EXPORT
|
||||
Status GetScalarExpr(ConstOpPtr op, LogicalTypePtr ty, ExprPtr* out);
|
||||
|
||||
/// \brief Construct an ArrayExpr containing an Operation given a logical type
|
||||
ARROW_EXPORT
|
||||
Status GetArrayExpr(ConstOpPtr op, LogicalTypePtr ty, ExprPtr* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,289 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_KERNEL_H
|
||||
#define ARROW_COMPUTE_KERNEL_H
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/record_batch.h"
|
||||
#include "arrow/scalar.h"
|
||||
#include "arrow/table.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/memory.h"
|
||||
#include "arrow/util/variant.h" // IWYU pragma: export
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
class FunctionContext;
|
||||
|
||||
/// \class OpKernel
|
||||
/// \brief Base class for operator kernels
|
||||
///
|
||||
/// Note to implementors:
|
||||
/// Operator kernels are intended to be the lowest level of an analytics/compute
|
||||
/// engine. They will generally not be exposed directly to end-users. Instead
|
||||
/// they will be wrapped by higher level constructs (e.g. top-level functions
|
||||
/// or physical execution plan nodes). These higher level constructs are
|
||||
/// responsible for user input validation and returning the appropriate
|
||||
/// error Status.
|
||||
///
|
||||
/// Due to this design, implementations of Call (the execution
|
||||
/// method on subclasses) should use assertions (i.e. DCHECK) to double-check
|
||||
/// parameter arguments when in higher level components returning an
|
||||
/// InvalidArgument error might be more appropriate.
|
||||
///
|
||||
class ARROW_EXPORT OpKernel {
|
||||
public:
|
||||
virtual ~OpKernel() = default;
|
||||
/// \brief EXPERIMENTAL The output data type of the kernel
|
||||
/// \return the output type
|
||||
virtual std::shared_ptr<DataType> out_type() const = 0;
|
||||
};
|
||||
|
||||
struct Datum;
|
||||
static inline bool CollectionEquals(const std::vector<Datum>& left,
|
||||
const std::vector<Datum>& right);
|
||||
|
||||
// Datums variants may have a length. This special value indicate that the
|
||||
// current variant does not have a length.
|
||||
constexpr int64_t kUnknownLength = -1;
|
||||
|
||||
/// \class Datum
|
||||
/// \brief Variant type for various Arrow C++ data structures
|
||||
struct ARROW_EXPORT Datum {
|
||||
enum type { NONE, SCALAR, ARRAY, CHUNKED_ARRAY, RECORD_BATCH, TABLE, COLLECTION };
|
||||
|
||||
util::variant<decltype(NULLPTR), std::shared_ptr<Scalar>, std::shared_ptr<ArrayData>,
|
||||
std::shared_ptr<ChunkedArray>, std::shared_ptr<RecordBatch>,
|
||||
std::shared_ptr<Table>, std::vector<Datum>>
|
||||
value;
|
||||
|
||||
/// \brief Empty datum, to be populated elsewhere
|
||||
Datum() : value(NULLPTR) {}
|
||||
|
||||
Datum(const std::shared_ptr<Scalar>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
Datum(const std::shared_ptr<ArrayData>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
|
||||
Datum(const std::shared_ptr<Array>& value) // NOLINT implicit conversion
|
||||
: Datum(value ? value->data() : NULLPTR) {}
|
||||
|
||||
Datum(const std::shared_ptr<ChunkedArray>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
Datum(const std::shared_ptr<RecordBatch>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
Datum(const std::shared_ptr<Table>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
Datum(const std::vector<Datum>& value) // NOLINT implicit conversion
|
||||
: value(value) {}
|
||||
|
||||
// Cast from subtypes of Array to Datum
|
||||
template <typename T,
|
||||
typename = typename std::enable_if<std::is_base_of<Array, T>::value>::type>
|
||||
Datum(const std::shared_ptr<T>& value) // NOLINT implicit conversion
|
||||
: Datum(std::shared_ptr<Array>(value)) {}
|
||||
|
||||
// Convenience constructors
|
||||
explicit Datum(bool value) : value(std::make_shared<BooleanScalar>(value)) {}
|
||||
explicit Datum(int8_t value) : value(std::make_shared<Int8Scalar>(value)) {}
|
||||
explicit Datum(uint8_t value) : value(std::make_shared<UInt8Scalar>(value)) {}
|
||||
explicit Datum(int16_t value) : value(std::make_shared<Int16Scalar>(value)) {}
|
||||
explicit Datum(uint16_t value) : value(std::make_shared<UInt16Scalar>(value)) {}
|
||||
explicit Datum(int32_t value) : value(std::make_shared<Int32Scalar>(value)) {}
|
||||
explicit Datum(uint32_t value) : value(std::make_shared<UInt32Scalar>(value)) {}
|
||||
explicit Datum(int64_t value) : value(std::make_shared<Int64Scalar>(value)) {}
|
||||
explicit Datum(uint64_t value) : value(std::make_shared<UInt64Scalar>(value)) {}
|
||||
explicit Datum(float value) : value(std::make_shared<FloatScalar>(value)) {}
|
||||
explicit Datum(double value) : value(std::make_shared<DoubleScalar>(value)) {}
|
||||
|
||||
~Datum() {}
|
||||
|
||||
Datum(const Datum& other) noexcept { this->value = other.value; }
|
||||
|
||||
Datum& operator=(const Datum& other) noexcept {
|
||||
value = other.value;
|
||||
return *this;
|
||||
}
|
||||
|
||||
// Define move constructor and move assignment, for better performance
|
||||
Datum(Datum&& other) noexcept : value(std::move(other.value)) {}
|
||||
|
||||
Datum& operator=(Datum&& other) noexcept {
|
||||
value = std::move(other.value);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Datum::type kind() const {
|
||||
switch (this->value.index()) {
|
||||
case 0:
|
||||
return Datum::NONE;
|
||||
case 1:
|
||||
return Datum::SCALAR;
|
||||
case 2:
|
||||
return Datum::ARRAY;
|
||||
case 3:
|
||||
return Datum::CHUNKED_ARRAY;
|
||||
case 4:
|
||||
return Datum::RECORD_BATCH;
|
||||
case 5:
|
||||
return Datum::TABLE;
|
||||
case 6:
|
||||
return Datum::COLLECTION;
|
||||
default:
|
||||
return Datum::NONE;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ArrayData> array() const {
|
||||
return util::get<std::shared_ptr<ArrayData>>(this->value);
|
||||
}
|
||||
|
||||
std::shared_ptr<Array> make_array() const {
|
||||
return MakeArray(util::get<std::shared_ptr<ArrayData>>(this->value));
|
||||
}
|
||||
|
||||
std::shared_ptr<ChunkedArray> chunked_array() const {
|
||||
return util::get<std::shared_ptr<ChunkedArray>>(this->value);
|
||||
}
|
||||
|
||||
std::shared_ptr<RecordBatch> record_batch() const {
|
||||
return util::get<std::shared_ptr<RecordBatch>>(this->value);
|
||||
}
|
||||
|
||||
std::shared_ptr<Table> table() const {
|
||||
return util::get<std::shared_ptr<Table>>(this->value);
|
||||
}
|
||||
|
||||
const std::vector<Datum> collection() const {
|
||||
return util::get<std::vector<Datum>>(this->value);
|
||||
}
|
||||
|
||||
std::shared_ptr<Scalar> scalar() const {
|
||||
return util::get<std::shared_ptr<Scalar>>(this->value);
|
||||
}
|
||||
|
||||
bool is_array() const { return this->kind() == Datum::ARRAY; }
|
||||
|
||||
bool is_arraylike() const {
|
||||
return this->kind() == Datum::ARRAY || this->kind() == Datum::CHUNKED_ARRAY;
|
||||
}
|
||||
|
||||
bool is_scalar() const { return this->kind() == Datum::SCALAR; }
|
||||
|
||||
/// \brief The value type of the variant, if any
|
||||
///
|
||||
/// \return nullptr if no type
|
||||
std::shared_ptr<DataType> type() const {
|
||||
if (this->kind() == Datum::ARRAY) {
|
||||
return util::get<std::shared_ptr<ArrayData>>(this->value)->type;
|
||||
} else if (this->kind() == Datum::CHUNKED_ARRAY) {
|
||||
return util::get<std::shared_ptr<ChunkedArray>>(this->value)->type();
|
||||
} else if (this->kind() == Datum::SCALAR) {
|
||||
return util::get<std::shared_ptr<Scalar>>(this->value)->type;
|
||||
}
|
||||
return NULLPTR;
|
||||
}
|
||||
|
||||
/// \brief The value length of the variant, if any
|
||||
///
|
||||
/// \return kUnknownLength if no type
|
||||
int64_t length() const {
|
||||
if (this->kind() == Datum::ARRAY) {
|
||||
return util::get<std::shared_ptr<ArrayData>>(this->value)->length;
|
||||
} else if (this->kind() == Datum::CHUNKED_ARRAY) {
|
||||
return util::get<std::shared_ptr<ChunkedArray>>(this->value)->length();
|
||||
} else if (this->kind() == Datum::SCALAR) {
|
||||
return 1;
|
||||
}
|
||||
return kUnknownLength;
|
||||
}
|
||||
|
||||
bool Equals(const Datum& other) const {
|
||||
if (this->kind() != other.kind()) return false;
|
||||
|
||||
switch (this->kind()) {
|
||||
case Datum::NONE:
|
||||
return true;
|
||||
case Datum::SCALAR:
|
||||
return internal::SharedPtrEquals(this->scalar(), other.scalar());
|
||||
case Datum::ARRAY:
|
||||
return internal::SharedPtrEquals(this->make_array(), other.make_array());
|
||||
case Datum::CHUNKED_ARRAY:
|
||||
return internal::SharedPtrEquals(this->chunked_array(), other.chunked_array());
|
||||
case Datum::RECORD_BATCH:
|
||||
return internal::SharedPtrEquals(this->record_batch(), other.record_batch());
|
||||
case Datum::TABLE:
|
||||
return internal::SharedPtrEquals(this->table(), other.table());
|
||||
case Datum::COLLECTION:
|
||||
return CollectionEquals(this->collection(), other.collection());
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/// \class UnaryKernel
|
||||
/// \brief An array-valued function of a single input argument.
|
||||
///
|
||||
/// Note to implementors: Try to avoid making kernels that allocate memory if
|
||||
/// the output size is a deterministic function of the Input Datum's metadata.
|
||||
/// Instead separate the logic of the kernel and allocations necessary into
|
||||
/// two different kernels. Some reusable kernels that allocate buffers
|
||||
/// and delegate computation to another kernel are available in util-internal.h.
|
||||
class ARROW_EXPORT UnaryKernel : public OpKernel {
|
||||
public:
|
||||
/// \brief Executes the kernel.
|
||||
///
|
||||
/// \param[in] ctx The function context for the kernel
|
||||
/// \param[in] input The kernel input data
|
||||
/// \param[out] out The output of the function. Each implementation of this
|
||||
/// function might assume different things about the existing contents of out
|
||||
/// (e.g. which buffers are preallocated). In the future it is expected that
|
||||
/// there will be a more generic mechansim for understanding the necessary
|
||||
/// contracts.
|
||||
virtual Status Call(FunctionContext* ctx, const Datum& input, Datum* out) = 0;
|
||||
};
|
||||
|
||||
/// \class BinaryKernel
|
||||
/// \brief An array-valued function of a two input arguments
|
||||
class ARROW_EXPORT BinaryKernel : public OpKernel {
|
||||
public:
|
||||
virtual Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
|
||||
Datum* out) = 0;
|
||||
};
|
||||
|
||||
static inline bool CollectionEquals(const std::vector<Datum>& left,
|
||||
const std::vector<Datum>& right) {
|
||||
if (left.size() != right.size()) return false;
|
||||
|
||||
for (size_t i = 0; i < left.size(); i++)
|
||||
if (!left[i].Equals(right[i])) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPUTE_KERNEL_H
|
@ -1,115 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/compute/kernel.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class Status;
|
||||
|
||||
namespace compute {
|
||||
|
||||
class FunctionContext;
|
||||
struct Datum;
|
||||
|
||||
/// AggregateFunction is an interface for Aggregates
|
||||
///
|
||||
/// An aggregates transforms an array into single result called a state via the
|
||||
/// Consume method.. State supports the merge operation via the Merge method.
|
||||
/// State can be sealed into a final result via the Finalize method.
|
||||
//
|
||||
/// State ownership is handled by callers, thus the interface exposes 3 methods
|
||||
/// for the caller to manage memory:
|
||||
/// - Size
|
||||
/// - New (placement new constructor invocation)
|
||||
/// - Delete (state desctructor)
|
||||
///
|
||||
/// Design inspired by ClickHouse aggregate functions.
|
||||
class AggregateFunction {
|
||||
public:
|
||||
/// \brief Consume an array into a state.
|
||||
virtual Status Consume(const Array& input, void* state) const = 0;
|
||||
|
||||
/// \brief Merge states.
|
||||
virtual Status Merge(const void* src, void* dst) const = 0;
|
||||
|
||||
/// \brief Convert state into a final result.
|
||||
virtual Status Finalize(const void* src, Datum* output) const = 0;
|
||||
|
||||
virtual ~AggregateFunction() {}
|
||||
|
||||
virtual std::shared_ptr<DataType> out_type() const = 0;
|
||||
|
||||
/// State management methods.
|
||||
virtual int64_t Size() const = 0;
|
||||
virtual void New(void* ptr) const = 0;
|
||||
virtual void Delete(void* ptr) const = 0;
|
||||
};
|
||||
|
||||
/// AggregateFunction partial implementation for static type state
|
||||
template <typename State>
|
||||
class AggregateFunctionStaticState : public AggregateFunction {
|
||||
virtual Status Consume(const Array& input, State* state) const = 0;
|
||||
virtual Status Merge(const State& src, State* dst) const = 0;
|
||||
virtual Status Finalize(const State& src, Datum* output) const = 0;
|
||||
|
||||
Status Consume(const Array& input, void* state) const final {
|
||||
return Consume(input, static_cast<State*>(state));
|
||||
}
|
||||
|
||||
Status Merge(const void* src, void* dst) const final {
|
||||
return Merge(*static_cast<const State*>(src), static_cast<State*>(dst));
|
||||
}
|
||||
|
||||
/// \brief Convert state into a final result.
|
||||
Status Finalize(const void* src, Datum* output) const final {
|
||||
return Finalize(*static_cast<const State*>(src), output);
|
||||
}
|
||||
|
||||
int64_t Size() const final { return sizeof(State); }
|
||||
|
||||
void New(void* ptr) const final {
|
||||
// By using placement-new syntax, the constructor of the State is invoked
|
||||
// in the memory location defined by the caller. This only supports State
|
||||
// with a parameter-less constructor.
|
||||
new (ptr) State;
|
||||
}
|
||||
|
||||
void Delete(void* ptr) const final { static_cast<State*>(ptr)->~State(); }
|
||||
};
|
||||
|
||||
/// \brief UnaryKernel implemented by an AggregateState
|
||||
class ARROW_EXPORT AggregateUnaryKernel : public UnaryKernel {
|
||||
public:
|
||||
explicit AggregateUnaryKernel(std::shared_ptr<AggregateFunction>& aggregate)
|
||||
: aggregate_function_(aggregate) {}
|
||||
|
||||
Status Call(FunctionContext* ctx, const Datum& input, Datum* out) override;
|
||||
|
||||
std::shared_ptr<DataType> out_type() const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<AggregateFunction> aggregate_function_;
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,76 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_KERNELS_BOOLEAN_H
|
||||
#define ARROW_COMPUTE_KERNELS_BOOLEAN_H
|
||||
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
|
||||
/// \brief Invert the values of a boolean datum
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value datum to invert
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.11.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Invert(FunctionContext* context, const Datum& value, Datum* out);
|
||||
|
||||
/// \brief Element-wise AND of two boolean datums
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] left left operand (array)
|
||||
/// \param[in] right right operand (array)
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.11.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status And(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
|
||||
|
||||
/// \brief Element-wise OR of two boolean datums
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] left left operand (array)
|
||||
/// \param[in] right right operand (array)
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.11.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Or(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
|
||||
|
||||
/// \brief Element-wise XOR of two boolean datums
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] left left operand (array)
|
||||
/// \param[in] right right operand (array)
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.11.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Xor(FunctionContext* context, const Datum& left, const Datum& right, Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPUTE_KERNELS_CAST_H
|
@ -1,98 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_KERNELS_CAST_H
|
||||
#define ARROW_COMPUTE_KERNELS_CAST_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
class UnaryKernel;
|
||||
|
||||
struct ARROW_EXPORT CastOptions {
|
||||
CastOptions()
|
||||
: allow_int_overflow(false),
|
||||
allow_time_truncate(false),
|
||||
allow_float_truncate(false),
|
||||
allow_invalid_utf8(false) {}
|
||||
|
||||
explicit CastOptions(bool safe)
|
||||
: allow_int_overflow(!safe),
|
||||
allow_time_truncate(!safe),
|
||||
allow_float_truncate(!safe),
|
||||
allow_invalid_utf8(!safe) {}
|
||||
|
||||
static CastOptions Safe() { return CastOptions(true); }
|
||||
|
||||
static CastOptions Unsafe() { return CastOptions(false); }
|
||||
|
||||
bool allow_int_overflow;
|
||||
bool allow_time_truncate;
|
||||
bool allow_float_truncate;
|
||||
// Indicate if conversions from Binary/FixedSizeBinary to string must
|
||||
// validate the utf8 payload.
|
||||
bool allow_invalid_utf8;
|
||||
};
|
||||
|
||||
/// \since 0.7.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> to_type,
|
||||
const CastOptions& options, std::unique_ptr<UnaryKernel>* kernel);
|
||||
|
||||
/// \brief Cast from one array type to another
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value array to cast
|
||||
/// \param[in] to_type type to cast to
|
||||
/// \param[in] options casting options
|
||||
/// \param[out] out resulting array
|
||||
///
|
||||
/// \since 0.7.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Cast(FunctionContext* context, const Array& value,
|
||||
std::shared_ptr<DataType> to_type, const CastOptions& options,
|
||||
std::shared_ptr<Array>* out);
|
||||
|
||||
/// \brief Cast from one value to another
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value datum to cast
|
||||
/// \param[in] to_type type to cast to
|
||||
/// \param[in] options casting options
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.8.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Cast(FunctionContext* context, const Datum& value,
|
||||
std::shared_ptr<DataType> to_type, const CastOptions& options, Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPUTE_KERNELS_CAST_H
|
@ -1,176 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/compute/kernel.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
struct Scalar;
|
||||
class Status;
|
||||
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
|
||||
/// CompareFunction is an interface for Comparisons
|
||||
///
|
||||
/// Comparisons take an array and emits a selection vector. The selection vector
|
||||
/// is given in the form of a bitmask as a BooleanArray result.
|
||||
class ARROW_EXPORT CompareFunction {
|
||||
public:
|
||||
/// Compare an array with a scalar argument.
|
||||
virtual Status Compare(const ArrayData& array, const Scalar& scalar,
|
||||
ArrayData* output) const = 0;
|
||||
|
||||
Status Compare(const ArrayData& array, const Scalar& scalar,
|
||||
std::shared_ptr<ArrayData>* output) {
|
||||
return Compare(array, scalar, output->get());
|
||||
}
|
||||
|
||||
virtual Status Compare(const Scalar& scalar, const ArrayData& array,
|
||||
ArrayData* output) const = 0;
|
||||
|
||||
Status Compare(const Scalar& scalar, const ArrayData& array,
|
||||
std::shared_ptr<ArrayData>* output) {
|
||||
return Compare(scalar, array, output->get());
|
||||
}
|
||||
|
||||
/// Compare an array with an array argument.
|
||||
virtual Status Compare(const ArrayData& lhs, const ArrayData& rhs,
|
||||
ArrayData* output) const = 0;
|
||||
|
||||
Status Compare(const ArrayData& lhs, const ArrayData& rhs,
|
||||
std::shared_ptr<ArrayData>* output) {
|
||||
return Compare(lhs, rhs, output->get());
|
||||
}
|
||||
|
||||
/// By default, CompareFunction emits a result bitmap.
|
||||
virtual std::shared_ptr<DataType> out_type() const { return boolean(); }
|
||||
|
||||
virtual ~CompareFunction() {}
|
||||
};
|
||||
|
||||
/// \brief BinaryKernel bound to a select function
|
||||
class ARROW_EXPORT CompareBinaryKernel : public BinaryKernel {
|
||||
public:
|
||||
explicit CompareBinaryKernel(std::shared_ptr<CompareFunction>& select)
|
||||
: compare_function_(select) {}
|
||||
|
||||
Status Call(FunctionContext* ctx, const Datum& left, const Datum& right,
|
||||
Datum* out) override;
|
||||
|
||||
static int64_t out_length(const Datum& left, const Datum& right) {
|
||||
if (left.kind() == Datum::ARRAY) return left.length();
|
||||
if (right.kind() == Datum::ARRAY) return right.length();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::shared_ptr<DataType> out_type() const override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<CompareFunction> compare_function_;
|
||||
};
|
||||
|
||||
enum CompareOperator {
|
||||
EQUAL,
|
||||
NOT_EQUAL,
|
||||
GREATER,
|
||||
GREATER_EQUAL,
|
||||
LESS,
|
||||
LESS_EQUAL,
|
||||
};
|
||||
|
||||
template <typename T, CompareOperator Op>
|
||||
struct Comparator;
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::EQUAL> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs == rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::NOT_EQUAL> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs != rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::GREATER> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs > rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::GREATER_EQUAL> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs >= rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::LESS> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs < rhs; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Comparator<T, CompareOperator::LESS_EQUAL> {
|
||||
constexpr static bool Compare(const T& lhs, const T& rhs) { return lhs <= rhs; }
|
||||
};
|
||||
|
||||
struct CompareOptions {
|
||||
explicit CompareOptions(CompareOperator op) : op(op) {}
|
||||
|
||||
enum CompareOperator op;
|
||||
};
|
||||
|
||||
/// \brief Return a Compare CompareFunction
|
||||
///
|
||||
/// \param[in] context FunctionContext passing context information
|
||||
/// \param[in] type required to specialize the kernel
|
||||
/// \param[in] options required to specify the compare operator
|
||||
///
|
||||
/// \since 0.14.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
std::shared_ptr<CompareFunction> MakeCompareFunction(FunctionContext* context,
|
||||
const DataType& type,
|
||||
struct CompareOptions options);
|
||||
|
||||
/// \brief Compare a numeric array with a scalar.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] left datum to compare, must be an Array
|
||||
/// \param[in] right datum to compare, must be a Scalar of the same type than
|
||||
/// left Datum.
|
||||
/// \param[in] options compare options
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// Note on floating point arrays, this uses ieee-754 compare semantics.
|
||||
///
|
||||
/// \since 0.14.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Compare(FunctionContext* context, const Datum& left, const Datum& right,
|
||||
struct CompareOptions options, Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,88 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/type.h"
|
||||
#include "arrow/type_traits.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
class AggregateFunction;
|
||||
|
||||
/// \class CountOptions
|
||||
///
|
||||
/// The user control the Count kernel behavior with this class. By default, the
|
||||
/// it will count all non-null values.
|
||||
struct ARROW_EXPORT CountOptions {
|
||||
enum mode {
|
||||
// Count all non-null values.
|
||||
COUNT_ALL = 0,
|
||||
// Count all null values.
|
||||
COUNT_NULL,
|
||||
};
|
||||
|
||||
explicit CountOptions(enum mode count_mode) : count_mode(count_mode) {}
|
||||
|
||||
enum mode count_mode = COUNT_ALL;
|
||||
};
|
||||
|
||||
/// \brief Return Count function aggregate
|
||||
ARROW_EXPORT
|
||||
std::shared_ptr<AggregateFunction> MakeCount(FunctionContext* context,
|
||||
const CountOptions& options);
|
||||
|
||||
/// \brief Count non-null (or null) values in an array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] options counting options, see CountOptions for more information
|
||||
/// \param[in] datum to count
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Count(FunctionContext* context, const CountOptions& options, const Datum& datum,
|
||||
Datum* out);
|
||||
|
||||
/// \brief Count non-null (or null) values in an array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] options counting options, see CountOptions for more information
|
||||
/// \param[in] array to count
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Count(FunctionContext* context, const CountOptions& options, const Array& array,
|
||||
Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,93 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/compute/kernel.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
|
||||
namespace compute {
|
||||
|
||||
class FunctionContext;
|
||||
|
||||
/// \brief Filter an array with a boolean selection filter
|
||||
///
|
||||
/// The output array will be populated with values from the input at positions
|
||||
/// where the selection filter is not 0. Nulls in the filter will result in nulls
|
||||
/// in the output.
|
||||
///
|
||||
/// For example given values = ["a", "b", "c", null, "e", "f"] and
|
||||
/// filter = [0, 1, 1, 0, null, 1], the output will be
|
||||
/// = ["b", "c", null, "f"]
|
||||
///
|
||||
/// \param[in] ctx the FunctionContext
|
||||
/// \param[in] values array to filter
|
||||
/// \param[in] filter indicates which values should be filtered out
|
||||
/// \param[out] out resulting array
|
||||
ARROW_EXPORT
|
||||
Status Filter(FunctionContext* ctx, const Array& values, const Array& filter,
|
||||
std::shared_ptr<Array>* out);
|
||||
|
||||
/// \brief Filter an array with a boolean selection filter
|
||||
///
|
||||
/// \param[in] ctx the FunctionContext
|
||||
/// \param[in] values datum to filter
|
||||
/// \param[in] filter indicates which values should be filtered out
|
||||
/// \param[out] out resulting datum
|
||||
ARROW_EXPORT
|
||||
Status Filter(FunctionContext* ctx, const Datum& values, const Datum& filter, Datum* out);
|
||||
|
||||
/// \brief BinaryKernel implementing Filter operation
|
||||
class ARROW_EXPORT FilterKernel : public BinaryKernel {
|
||||
public:
|
||||
explicit FilterKernel(const std::shared_ptr<DataType>& type) : type_(type) {}
|
||||
|
||||
/// \brief BinaryKernel interface
|
||||
///
|
||||
/// delegates to subclasses via Filter()
|
||||
Status Call(FunctionContext* ctx, const Datum& values, const Datum& filter,
|
||||
Datum* out) override;
|
||||
|
||||
/// \brief output type of this kernel (identical to type of values filtered)
|
||||
std::shared_ptr<DataType> out_type() const override { return type_; }
|
||||
|
||||
/// \brief factory for FilterKernels
|
||||
///
|
||||
/// \param[in] value_type constructed FilterKernel will support filtering
|
||||
/// values of this type
|
||||
/// \param[out] out created kernel
|
||||
static Status Make(const std::shared_ptr<DataType>& value_type,
|
||||
std::unique_ptr<FilterKernel>* out);
|
||||
|
||||
/// \brief single-array implementation
|
||||
virtual Status Filter(FunctionContext* ctx, const Array& values,
|
||||
const BooleanArray& filter, int64_t length,
|
||||
std::shared_ptr<Array>* out) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<DataType> type_;
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,105 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_KERNELS_HASH_H
|
||||
#define ARROW_COMPUTE_KERNELS_HASH_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/compute/kernel.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
struct ArrayData;
|
||||
|
||||
namespace compute {
|
||||
|
||||
class FunctionContext;
|
||||
|
||||
/// \brief Compute unique elements from an array-like object
|
||||
///
|
||||
/// Note if a null occurs in the input it will NOT be included in the output.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] datum array-like input
|
||||
/// \param[out] out result as Array
|
||||
///
|
||||
/// \since 0.8.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Unique(FunctionContext* context, const Datum& datum, std::shared_ptr<Array>* out);
|
||||
|
||||
// Constants for accessing the output of ValueCounts
|
||||
ARROW_EXPORT extern const char kValuesFieldName[];
|
||||
ARROW_EXPORT extern const char kCountsFieldName[];
|
||||
ARROW_EXPORT extern const int32_t kValuesFieldIndex;
|
||||
ARROW_EXPORT extern const int32_t kCountsFieldIndex;
|
||||
/// \brief Return counts of unique elements from an array-like object.
|
||||
///
|
||||
/// Note that the counts do not include counts for nulls in the array. These can be
|
||||
/// obtained separately from metadata.
|
||||
///
|
||||
/// For floating point arrays there is no attempt to normalize -0.0, 0.0 and NaN values
|
||||
/// which can lead to unexpected results if the input Array has these values.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value array-like input
|
||||
/// \param[out] counts An array of <input type "Values", int64_t "Counts"> structs.
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status ValueCounts(FunctionContext* context, const Datum& value,
|
||||
std::shared_ptr<Array>* counts);
|
||||
|
||||
/// \brief Dictionary-encode values in an array-like object
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] data array-like input
|
||||
/// \param[out] out result with same shape and type as input
|
||||
///
|
||||
/// \since 0.8.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status DictionaryEncode(FunctionContext* context, const Datum& data, Datum* out);
|
||||
|
||||
// TODO(wesm): Define API for incremental dictionary encoding
|
||||
|
||||
// TODO(wesm): Define API for regularizing DictionaryArray objects with
|
||||
// different dictionaries
|
||||
|
||||
//
|
||||
// ARROW_EXPORT
|
||||
// Status DictionaryEncode(FunctionContext* context, const Datum& data,
|
||||
// const Array& prior_dictionary, Datum* out);
|
||||
|
||||
// TODO(wesm): Implement these next
|
||||
// ARROW_EXPORT
|
||||
// Status Match(FunctionContext* context, const Datum& values, const Datum& member_set,
|
||||
// Datum* out);
|
||||
|
||||
// ARROW_EXPORT
|
||||
// Status IsIn(FunctionContext* context, const Datum& values, const Datum& member_set,
|
||||
// Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_COMPUTE_KERNELS_HASH_H
|
@ -1,66 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <type_traits>
|
||||
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/type.h"
|
||||
#include "arrow/type_traits.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
class AggregateFunction;
|
||||
|
||||
ARROW_EXPORT
|
||||
std::shared_ptr<AggregateFunction> MakeMeanAggregateFunction(const DataType& type,
|
||||
FunctionContext* context);
|
||||
|
||||
/// \brief Compute the mean of a numeric array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value datum to compute the mean, expecting Array
|
||||
/// \param[out] mean datum of the computed mean as a DoubleScalar
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Mean(FunctionContext* context, const Datum& value, Datum* mean);
|
||||
|
||||
/// \brief Compute the mean of a numeric array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] array to compute the mean
|
||||
/// \param[out] mean datum of the computed mean as a DoubleScalar
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Mean(FunctionContext* context, const Array& array, Datum* mean);
|
||||
|
||||
} // namespace compute
|
||||
}; // namespace arrow
|
@ -1,70 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
class Status;
|
||||
|
||||
namespace compute {
|
||||
|
||||
struct Datum;
|
||||
class FunctionContext;
|
||||
class AggregateFunction;
|
||||
|
||||
/// \brief Return a Sum Kernel
|
||||
///
|
||||
/// \param[in] type required to specialize the kernel
|
||||
/// \param[in] context the FunctionContext
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
std::shared_ptr<AggregateFunction> MakeSumAggregateFunction(const DataType& type,
|
||||
FunctionContext* context);
|
||||
|
||||
/// \brief Sum values of a numeric array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] value datum to sum, expecting Array or ChunkedArray
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Sum(FunctionContext* context, const Datum& value, Datum* out);
|
||||
|
||||
/// \brief Sum values of a numeric array.
|
||||
///
|
||||
/// \param[in] context the FunctionContext
|
||||
/// \param[in] array to sum
|
||||
/// \param[out] out resulting datum
|
||||
///
|
||||
/// \since 0.13.0
|
||||
/// \note API not yet finalized
|
||||
ARROW_EXPORT
|
||||
Status Sum(FunctionContext* context, const Array& array, Datum* out);
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,101 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/compute/kernel.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
|
||||
namespace compute {
|
||||
|
||||
class FunctionContext;
|
||||
|
||||
struct ARROW_EXPORT TakeOptions {};
|
||||
|
||||
/// \brief Take from an array of values at indices in another array
|
||||
///
|
||||
/// The output array will be of the same type as the input values
|
||||
/// array, with elements taken from the values array at the given
|
||||
/// indices. If an index is null then the taken element will be null.
|
||||
///
|
||||
/// For example given values = ["a", "b", "c", null, "e", "f"] and
|
||||
/// indices = [2, 1, null, 3], the output will be
|
||||
/// = [values[2], values[1], null, values[3]]
|
||||
/// = ["c", "b", null, null]
|
||||
///
|
||||
/// \param[in] ctx the FunctionContext
|
||||
/// \param[in] values array from which to take
|
||||
/// \param[in] indices which values to take
|
||||
/// \param[in] options options
|
||||
/// \param[out] out resulting array
|
||||
ARROW_EXPORT
|
||||
Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
|
||||
const TakeOptions& options, std::shared_ptr<Array>* out);
|
||||
|
||||
/// \brief Take from an array of values at indices in another array
|
||||
///
|
||||
/// \param[in] ctx the FunctionContext
|
||||
/// \param[in] values datum from which to take
|
||||
/// \param[in] indices which values to take
|
||||
/// \param[in] options options
|
||||
/// \param[out] out resulting datum
|
||||
ARROW_EXPORT
|
||||
Status Take(FunctionContext* ctx, const Datum& values, const Datum& indices,
|
||||
const TakeOptions& options, Datum* out);
|
||||
|
||||
/// \brief BinaryKernel implementing Take operation
|
||||
class ARROW_EXPORT TakeKernel : public BinaryKernel {
|
||||
public:
|
||||
explicit TakeKernel(const std::shared_ptr<DataType>& type, TakeOptions options = {})
|
||||
: type_(type) {}
|
||||
|
||||
/// \brief BinaryKernel interface
|
||||
///
|
||||
/// delegates to subclasses via Take()
|
||||
Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices,
|
||||
Datum* out) override;
|
||||
|
||||
/// \brief output type of this kernel (identical to type of values taken)
|
||||
std::shared_ptr<DataType> out_type() const override { return type_; }
|
||||
|
||||
/// \brief factory for TakeKernels
|
||||
///
|
||||
/// \param[in] value_type constructed TakeKernel will support taking
|
||||
/// values of this type
|
||||
/// \param[in] index_type constructed TakeKernel will support taking
|
||||
/// with indices of this type
|
||||
/// \param[out] out created kernel
|
||||
static Status Make(const std::shared_ptr<DataType>& value_type,
|
||||
const std::shared_ptr<DataType>& index_type,
|
||||
std::unique_ptr<TakeKernel>* out);
|
||||
|
||||
/// \brief single-array implementation
|
||||
virtual Status Take(FunctionContext* ctx, const Array& values, const Array& indices,
|
||||
std::shared_ptr<Array>* out) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<DataType> type_;
|
||||
};
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,308 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
// Metadata objects for creating well-typed expressions. These are distinct
|
||||
// from (and higher level than) arrow::DataType as some type parameters (like
|
||||
// decimal scale and precision) may not be known at expression build time, and
|
||||
// these are resolved later on evaluation
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "arrow/compute/type_fwd.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Status;
|
||||
|
||||
namespace compute {
|
||||
|
||||
class Expr;
|
||||
|
||||
/// \brief An object that represents either a single concrete value type or a
|
||||
/// group of related types, to help with expression type validation and other
|
||||
/// purposes
|
||||
class ARROW_EXPORT LogicalType {
|
||||
public:
|
||||
enum Id {
|
||||
ANY,
|
||||
NUMBER,
|
||||
INTEGER,
|
||||
SIGNED_INTEGER,
|
||||
UNSIGNED_INTEGER,
|
||||
FLOATING,
|
||||
NULL_,
|
||||
BOOL,
|
||||
UINT8,
|
||||
INT8,
|
||||
UINT16,
|
||||
INT16,
|
||||
UINT32,
|
||||
INT32,
|
||||
UINT64,
|
||||
INT64,
|
||||
FLOAT16,
|
||||
FLOAT32,
|
||||
FLOAT64,
|
||||
BINARY,
|
||||
UTF8,
|
||||
DATE,
|
||||
TIME,
|
||||
TIMESTAMP,
|
||||
DECIMAL,
|
||||
LIST,
|
||||
STRUCT
|
||||
};
|
||||
|
||||
Id id() const { return id_; }
|
||||
|
||||
virtual ~LogicalType() = default;
|
||||
|
||||
virtual std::string ToString() const = 0;
|
||||
|
||||
/// \brief Check if expression is an instance of this type class
|
||||
virtual bool IsInstance(const Expr& expr) const = 0;
|
||||
|
||||
/// \brief Get a logical expression type from a concrete Arrow in-memory
|
||||
/// array type
|
||||
static Status FromArrow(const ::arrow::DataType& type, LogicalTypePtr* out);
|
||||
|
||||
protected:
|
||||
explicit LogicalType(Id id) : id_(id) {}
|
||||
Id id_;
|
||||
};
|
||||
|
||||
namespace type {
|
||||
|
||||
/// \brief Logical type for any value type
|
||||
class ARROW_EXPORT Any : public LogicalType {
|
||||
public:
|
||||
Any() : LogicalType(LogicalType::ANY) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for null
|
||||
class ARROW_EXPORT Null : public LogicalType {
|
||||
public:
|
||||
Null() : LogicalType(LogicalType::NULL_) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for concrete boolean
|
||||
class ARROW_EXPORT Bool : public LogicalType {
|
||||
public:
|
||||
Bool() : LogicalType(LogicalType::BOOL) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for any number (integer or floating point)
|
||||
class ARROW_EXPORT Number : public LogicalType {
|
||||
public:
|
||||
Number() : Number(LogicalType::NUMBER) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit Number(Id type_id) : LogicalType(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for any integer
|
||||
class ARROW_EXPORT Integer : public Number {
|
||||
public:
|
||||
Integer() : Integer(LogicalType::INTEGER) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit Integer(Id type_id) : Number(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for any floating point number
|
||||
class ARROW_EXPORT Floating : public Number {
|
||||
public:
|
||||
Floating() : Floating(LogicalType::FLOATING) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit Floating(Id type_id) : Number(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for any signed integer
|
||||
class ARROW_EXPORT SignedInteger : public Integer {
|
||||
public:
|
||||
SignedInteger() : SignedInteger(LogicalType::SIGNED_INTEGER) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit SignedInteger(Id type_id) : Integer(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for any unsigned integer
|
||||
class ARROW_EXPORT UnsignedInteger : public Integer {
|
||||
public:
|
||||
UnsignedInteger() : UnsignedInteger(LogicalType::UNSIGNED_INTEGER) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit UnsignedInteger(Id type_id) : Integer(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for int8
|
||||
class ARROW_EXPORT Int8 : public SignedInteger {
|
||||
public:
|
||||
Int8() : SignedInteger(LogicalType::INT8) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for int16
|
||||
class ARROW_EXPORT Int16 : public SignedInteger {
|
||||
public:
|
||||
Int16() : SignedInteger(LogicalType::INT16) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for int32
|
||||
class ARROW_EXPORT Int32 : public SignedInteger {
|
||||
public:
|
||||
Int32() : SignedInteger(LogicalType::INT32) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for int64
|
||||
class ARROW_EXPORT Int64 : public SignedInteger {
|
||||
public:
|
||||
Int64() : SignedInteger(LogicalType::INT64) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for uint8
|
||||
class ARROW_EXPORT UInt8 : public UnsignedInteger {
|
||||
public:
|
||||
UInt8() : UnsignedInteger(LogicalType::UINT8) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for uint16
|
||||
class ARROW_EXPORT UInt16 : public UnsignedInteger {
|
||||
public:
|
||||
UInt16() : UnsignedInteger(LogicalType::UINT16) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for uint32
|
||||
class ARROW_EXPORT UInt32 : public UnsignedInteger {
|
||||
public:
|
||||
UInt32() : UnsignedInteger(LogicalType::UINT32) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for uint64
|
||||
class ARROW_EXPORT UInt64 : public UnsignedInteger {
|
||||
public:
|
||||
UInt64() : UnsignedInteger(LogicalType::UINT64) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for 16-bit floating point
|
||||
class ARROW_EXPORT Float16 : public Floating {
|
||||
public:
|
||||
Float16() : Floating(LogicalType::FLOAT16) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for 32-bit floating point
|
||||
class ARROW_EXPORT Float32 : public Floating {
|
||||
public:
|
||||
Float32() : Floating(LogicalType::FLOAT32) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for 64-bit floating point
|
||||
class ARROW_EXPORT Float64 : public Floating {
|
||||
public:
|
||||
Float64() : Floating(LogicalType::FLOAT64) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
/// \brief Logical type for variable-size binary
|
||||
class ARROW_EXPORT Binary : public LogicalType {
|
||||
public:
|
||||
Binary() : Binary(LogicalType::BINARY) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
|
||||
protected:
|
||||
explicit Binary(Id type_id) : LogicalType(type_id) {}
|
||||
};
|
||||
|
||||
/// \brief Logical type for variable-size binary
|
||||
class ARROW_EXPORT Utf8 : public Binary {
|
||||
public:
|
||||
Utf8() : Binary(LogicalType::UTF8) {}
|
||||
bool IsInstance(const Expr& expr) const override;
|
||||
std::string ToString() const override;
|
||||
};
|
||||
|
||||
#define SIMPLE_TYPE_FACTORY(NAME) ARROW_EXPORT LogicalTypePtr NAME();
|
||||
|
||||
SIMPLE_TYPE_FACTORY(any);
|
||||
SIMPLE_TYPE_FACTORY(null);
|
||||
SIMPLE_TYPE_FACTORY(boolean);
|
||||
SIMPLE_TYPE_FACTORY(number);
|
||||
SIMPLE_TYPE_FACTORY(integer);
|
||||
SIMPLE_TYPE_FACTORY(signed_integer);
|
||||
SIMPLE_TYPE_FACTORY(unsigned_integer);
|
||||
SIMPLE_TYPE_FACTORY(floating);
|
||||
SIMPLE_TYPE_FACTORY(int8);
|
||||
SIMPLE_TYPE_FACTORY(int16);
|
||||
SIMPLE_TYPE_FACTORY(int32);
|
||||
SIMPLE_TYPE_FACTORY(int64);
|
||||
SIMPLE_TYPE_FACTORY(uint8);
|
||||
SIMPLE_TYPE_FACTORY(uint16);
|
||||
SIMPLE_TYPE_FACTORY(uint32);
|
||||
SIMPLE_TYPE_FACTORY(uint64);
|
||||
SIMPLE_TYPE_FACTORY(float16);
|
||||
SIMPLE_TYPE_FACTORY(float32);
|
||||
SIMPLE_TYPE_FACTORY(float64);
|
||||
SIMPLE_TYPE_FACTORY(binary);
|
||||
SIMPLE_TYPE_FACTORY(utf8);
|
||||
|
||||
#undef SIMPLE_TYPE_FACTORY
|
||||
|
||||
} // namespace type
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,52 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/compute/type_fwd.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Status;
|
||||
|
||||
namespace compute {
|
||||
|
||||
/// \brief An operation is a node in a computation graph, taking input data
|
||||
/// expression dependencies and emitting an output expression
|
||||
class ARROW_EXPORT Operation : public std::enable_shared_from_this<Operation> {
|
||||
public:
|
||||
virtual ~Operation() = default;
|
||||
|
||||
/// \brief Check input expression arguments and output the type of resulting
|
||||
/// expression that this operation produces. If the input arguments are
|
||||
/// invalid, error Status is returned
|
||||
/// \param[out] out the returned well-typed expression
|
||||
/// \return success or failure
|
||||
virtual Status ToExpr(ExprPtr* out) const = 0;
|
||||
|
||||
/// \brief Return the input expressions used to instantiate the
|
||||
/// operation. The default implementation returns an empty vector
|
||||
/// \return a vector of expressions
|
||||
virtual std::vector<ExprPtr> input_args() const;
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,110 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_COMPUTE_TEST_UTIL_H
|
||||
#define ARROW_COMPUTE_TEST_UTIL_H
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/memory_pool.h"
|
||||
#include "arrow/testing/gtest_util.h"
|
||||
#include "arrow/testing/util.h"
|
||||
#include "arrow/type.h"
|
||||
|
||||
#include "arrow/compute/context.h"
|
||||
#include "arrow/compute/kernel.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
class ComputeFixture {
|
||||
public:
|
||||
ComputeFixture() : ctx_(default_memory_pool()) {}
|
||||
|
||||
protected:
|
||||
FunctionContext ctx_;
|
||||
};
|
||||
|
||||
class MockUnaryKernel : public UnaryKernel {
|
||||
public:
|
||||
MOCK_METHOD3(Call, Status(FunctionContext* ctx, const Datum& input, Datum* out));
|
||||
MOCK_CONST_METHOD0(out_type, std::shared_ptr<DataType>());
|
||||
};
|
||||
|
||||
class MockBinaryKernel : public BinaryKernel {
|
||||
public:
|
||||
MOCK_METHOD4(Call, Status(FunctionContext* ctx, const Datum& left, const Datum& right,
|
||||
Datum* out));
|
||||
MOCK_CONST_METHOD0(out_type, std::shared_ptr<DataType>());
|
||||
};
|
||||
|
||||
template <typename Type, typename T>
|
||||
std::shared_ptr<Array> _MakeArray(const std::shared_ptr<DataType>& type,
|
||||
const std::vector<T>& values,
|
||||
const std::vector<bool>& is_valid) {
|
||||
std::shared_ptr<Array> result;
|
||||
if (is_valid.size() > 0) {
|
||||
ArrayFromVector<Type, T>(type, is_valid, values, &result);
|
||||
} else {
|
||||
ArrayFromVector<Type, T>(type, values, &result);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename Type, typename Enable = void>
|
||||
struct DatumEqual {};
|
||||
|
||||
template <typename Type>
|
||||
struct DatumEqual<Type, typename std::enable_if<IsFloatingPoint<Type>::value>::type> {
|
||||
static constexpr double kArbitraryDoubleErrorBound = 1.0;
|
||||
using ScalarType = typename TypeTraits<Type>::ScalarType;
|
||||
|
||||
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
|
||||
ASSERT_EQ(lhs.kind(), rhs.kind());
|
||||
if (lhs.kind() == Datum::SCALAR) {
|
||||
auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
|
||||
auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
|
||||
ASSERT_EQ(left->is_valid, right->is_valid);
|
||||
ASSERT_EQ(left->type->id(), right->type->id());
|
||||
ASSERT_NEAR(left->value, right->value, kArbitraryDoubleErrorBound);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Type>
|
||||
struct DatumEqual<Type, typename std::enable_if<!IsFloatingPoint<Type>::value>::type> {
|
||||
using ScalarType = typename TypeTraits<Type>::ScalarType;
|
||||
static void EnsureEqual(const Datum& lhs, const Datum& rhs) {
|
||||
ASSERT_EQ(lhs.kind(), rhs.kind());
|
||||
if (lhs.kind() == Datum::SCALAR) {
|
||||
auto left = internal::checked_cast<const ScalarType*>(lhs.scalar().get());
|
||||
auto right = internal::checked_cast<const ScalarType*>(rhs.scalar().get());
|
||||
ASSERT_EQ(left->is_valid, right->is_valid);
|
||||
ASSERT_EQ(left->type->id(), right->type->id());
|
||||
ASSERT_EQ(left->value, right->value);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
||||
|
||||
#endif
|
@ -1,38 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/type_fwd.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace compute {
|
||||
|
||||
class Expr;
|
||||
class LogicalType;
|
||||
class Operation;
|
||||
|
||||
using ArrowTypePtr = std::shared_ptr<::arrow::DataType>;
|
||||
using ExprPtr = std::shared_ptr<Expr>;
|
||||
using ConstOpPtr = std::shared_ptr<const Operation>;
|
||||
using OpPtr = std::shared_ptr<Operation>;
|
||||
using LogicalTypePtr = std::shared_ptr<LogicalType>;
|
||||
|
||||
} // namespace compute
|
||||
} // namespace arrow
|
@ -1,24 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_API_H
|
||||
#define ARROW_CSV_API_H
|
||||
|
||||
#include "arrow/csv/options.h"
|
||||
#include "arrow/csv/reader.h"
|
||||
|
||||
#endif // ARROW_CSV_API_H
|
@ -1,69 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_CHUNKER_H
|
||||
#define ARROW_CSV_CHUNKER_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#include "arrow/csv/options.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace csv {
|
||||
|
||||
/// \class Chunker
|
||||
/// \brief A reusable block-based chunker for CSV data
|
||||
///
|
||||
/// The chunker takes a block of CSV data and finds a suitable place
|
||||
/// to cut it up without splitting a row.
|
||||
/// If the block is truncated (i.e. not all data can be chunked), it is up
|
||||
/// to the caller to arrange the next block to start with the trailing data.
|
||||
///
|
||||
/// Note: if the previous block ends with CR (0x0d) and a new block starts
|
||||
/// with LF (0x0a), the chunker will consider the leading newline as an empty line.
|
||||
class ARROW_EXPORT Chunker {
|
||||
public:
|
||||
explicit Chunker(ParseOptions options);
|
||||
|
||||
/// \brief Carve up a chunk in a block of data
|
||||
///
|
||||
/// Process a block of CSV data, reading up to size bytes.
|
||||
/// The number of bytes in the chunk is returned in out_size.
|
||||
Status Process(const char* data, uint32_t size, uint32_t* out_size);
|
||||
|
||||
protected:
|
||||
ARROW_DISALLOW_COPY_AND_ASSIGN(Chunker);
|
||||
|
||||
// Like Process(), but specialized for some parsing options
|
||||
template <bool quoting, bool escaping>
|
||||
Status ProcessSpecialized(const char* data, uint32_t size, uint32_t* out_size);
|
||||
|
||||
// Detect a single line from the data pointer. Return the line end,
|
||||
// or nullptr if the remaining line is truncated.
|
||||
template <bool quoting, bool escaping>
|
||||
inline const char* ReadLine(const char* data, const char* data_end);
|
||||
|
||||
ParseOptions options_;
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_CHUNKER_H
|
@ -1,87 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_COLUMN_BUILDER_H
|
||||
#define ARROW_CSV_COLUMN_BUILDER_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/array.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class ChunkedArray;
|
||||
class DataType;
|
||||
|
||||
namespace internal {
|
||||
|
||||
class TaskGroup;
|
||||
|
||||
} // namespace internal
|
||||
|
||||
namespace csv {
|
||||
|
||||
class BlockParser;
|
||||
struct ConvertOptions;
|
||||
|
||||
class ARROW_EXPORT ColumnBuilder {
|
||||
public:
|
||||
virtual ~ColumnBuilder() = default;
|
||||
|
||||
/// Spawn a task that will try to convert and append the given CSV block.
|
||||
/// All calls to Append() should happen on the same thread, otherwise
|
||||
/// call Insert() instead.
|
||||
virtual void Append(const std::shared_ptr<BlockParser>& parser);
|
||||
|
||||
/// Spawn a task that will try to convert and insert the given CSV block
|
||||
virtual void Insert(int64_t block_index,
|
||||
const std::shared_ptr<BlockParser>& parser) = 0;
|
||||
|
||||
/// Return the final chunked array. The TaskGroup _must_ have finished!
|
||||
virtual Status Finish(std::shared_ptr<ChunkedArray>* out) = 0;
|
||||
|
||||
/// Change the task group. The previous TaskGroup _must_ have finished!
|
||||
void SetTaskGroup(const std::shared_ptr<internal::TaskGroup>& task_group);
|
||||
|
||||
std::shared_ptr<internal::TaskGroup> task_group() { return task_group_; }
|
||||
|
||||
/// Construct a strictly-typed ColumnBuilder.
|
||||
static Status Make(const std::shared_ptr<DataType>& type, int32_t col_index,
|
||||
const ConvertOptions& options,
|
||||
const std::shared_ptr<internal::TaskGroup>& task_group,
|
||||
std::shared_ptr<ColumnBuilder>* out);
|
||||
|
||||
/// Construct a type-inferring ColumnBuilder.
|
||||
static Status Make(int32_t col_index, const ConvertOptions& options,
|
||||
const std::shared_ptr<internal::TaskGroup>& task_group,
|
||||
std::shared_ptr<ColumnBuilder>* out);
|
||||
|
||||
protected:
|
||||
explicit ColumnBuilder(const std::shared_ptr<internal::TaskGroup>& task_group)
|
||||
: task_group_(task_group) {}
|
||||
|
||||
std::shared_ptr<internal::TaskGroup> task_group_;
|
||||
ArrayVector chunks_;
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_COLUMN_BUILDER_H
|
@ -1,68 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_CONVERTER_H
|
||||
#define ARROW_CSV_CONVERTER_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/csv/options.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class Array;
|
||||
class DataType;
|
||||
class MemoryPool;
|
||||
class Status;
|
||||
|
||||
namespace csv {
|
||||
|
||||
class BlockParser;
|
||||
|
||||
class ARROW_EXPORT Converter {
|
||||
public:
|
||||
Converter(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
|
||||
MemoryPool* pool);
|
||||
virtual ~Converter() = default;
|
||||
|
||||
virtual Status Convert(const BlockParser& parser, int32_t col_index,
|
||||
std::shared_ptr<Array>* out) = 0;
|
||||
|
||||
std::shared_ptr<DataType> type() const { return type_; }
|
||||
|
||||
static Status Make(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
|
||||
std::shared_ptr<Converter>* out);
|
||||
static Status Make(const std::shared_ptr<DataType>& type, const ConvertOptions& options,
|
||||
MemoryPool* pool, std::shared_ptr<Converter>* out);
|
||||
|
||||
protected:
|
||||
ARROW_DISALLOW_COPY_AND_ASSIGN(Converter);
|
||||
|
||||
virtual Status Initialize() = 0;
|
||||
|
||||
const ConvertOptions options_;
|
||||
MemoryPool* pool_;
|
||||
std::shared_ptr<DataType> type_;
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_CONVERTER_H
|
@ -1,98 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_OPTIONS_H
|
||||
#define ARROW_CSV_OPTIONS_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class DataType;
|
||||
|
||||
namespace csv {
|
||||
|
||||
struct ARROW_EXPORT ParseOptions {
|
||||
// Parsing options
|
||||
|
||||
// Field delimiter
|
||||
char delimiter = ',';
|
||||
// Whether quoting is used
|
||||
bool quoting = true;
|
||||
// Quoting character (if `quoting` is true)
|
||||
char quote_char = '"';
|
||||
// Whether a quote inside a value is double-quoted
|
||||
bool double_quote = true;
|
||||
// Whether escaping is used
|
||||
bool escaping = false;
|
||||
// Escaping character (if `escaping` is true)
|
||||
char escape_char = '\\';
|
||||
// Whether values are allowed to contain CR (0x0d) and LF (0x0a) characters
|
||||
bool newlines_in_values = false;
|
||||
// Whether empty lines are ignored. If false, an empty line represents
|
||||
// a single empty value (assuming a one-column CSV file).
|
||||
bool ignore_empty_lines = true;
|
||||
|
||||
// XXX Should this be in ReadOptions?
|
||||
// Number of header rows to skip (including the first row containing column names)
|
||||
int32_t header_rows = 1;
|
||||
|
||||
static ParseOptions Defaults();
|
||||
};
|
||||
|
||||
struct ARROW_EXPORT ConvertOptions {
|
||||
// Conversion options
|
||||
|
||||
// Whether to check UTF8 validity of string columns
|
||||
bool check_utf8 = true;
|
||||
// Optional per-column types (disabling type inference on those columns)
|
||||
std::unordered_map<std::string, std::shared_ptr<DataType>> column_types;
|
||||
// Recognized spellings for null values
|
||||
std::vector<std::string> null_values;
|
||||
// Recognized spellings for boolean values
|
||||
std::vector<std::string> true_values;
|
||||
std::vector<std::string> false_values;
|
||||
// Whether string / binary columns can have null values.
|
||||
// If true, then strings in "null_values" are considered null for string columns.
|
||||
// If false, then all strings are valid string values.
|
||||
bool strings_can_be_null = false;
|
||||
|
||||
static ConvertOptions Defaults();
|
||||
};
|
||||
|
||||
struct ARROW_EXPORT ReadOptions {
|
||||
// Reader options
|
||||
|
||||
// Whether to use the global CPU thread pool
|
||||
bool use_threads = true;
|
||||
// Block size we request from the IO layer; also determines the size of
|
||||
// chunks when use_threads is true
|
||||
int32_t block_size = 1 << 20; // 1 MB
|
||||
|
||||
static ReadOptions Defaults();
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_OPTIONS_H
|
@ -1,149 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_PARSER_H
|
||||
#define ARROW_CSV_PARSER_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/buffer.h"
|
||||
#include "arrow/csv/options.h"
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/macros.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class MemoryPool;
|
||||
|
||||
namespace csv {
|
||||
|
||||
constexpr int32_t kMaxParserNumRows = 100000;
|
||||
|
||||
/// \class BlockParser
|
||||
/// \brief A reusable block-based parser for CSV data
|
||||
///
|
||||
/// The parser takes a block of CSV data and delimits rows and fields,
|
||||
/// unquoting and unescaping them on the fly. Parsed data is own by the
|
||||
/// parser, so the original buffer can be discarded after Parse() returns.
|
||||
///
|
||||
/// If the block is truncated (i.e. not all data can be parsed), it is up
|
||||
/// to the caller to arrange the next block to start with the trailing data.
|
||||
/// Also, if the previous block ends with CR (0x0d) and a new block starts
|
||||
/// with LF (0x0a), the parser will consider the leading newline as an empty
|
||||
/// line; the caller should therefore strip it.
|
||||
class ARROW_EXPORT BlockParser {
|
||||
public:
|
||||
explicit BlockParser(ParseOptions options, int32_t num_cols = -1,
|
||||
int32_t max_num_rows = kMaxParserNumRows);
|
||||
explicit BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_cols = -1,
|
||||
int32_t max_num_rows = kMaxParserNumRows);
|
||||
|
||||
/// \brief Parse a block of data
|
||||
///
|
||||
/// Parse a block of CSV data, ingesting up to max_num_rows rows.
|
||||
/// The number of bytes actually parsed is returned in out_size.
|
||||
Status Parse(const char* data, uint32_t size, uint32_t* out_size);
|
||||
|
||||
/// \brief Parse the final block of data
|
||||
///
|
||||
/// Like Parse(), but called with the final block in a file.
|
||||
/// The last row may lack a trailing line separator.
|
||||
Status ParseFinal(const char* data, uint32_t size, uint32_t* out_size);
|
||||
|
||||
/// \brief Return the number of parsed rows
|
||||
int32_t num_rows() const { return num_rows_; }
|
||||
/// \brief Return the number of parsed columns
|
||||
int32_t num_cols() const { return num_cols_; }
|
||||
/// \brief Return the total size in bytes of parsed data
|
||||
uint32_t num_bytes() const { return parsed_size_; }
|
||||
|
||||
/// \brief Visit parsed values in a column
|
||||
///
|
||||
/// The signature of the visitor is
|
||||
/// Status(const uint8_t* data, uint32_t size, bool quoted)
|
||||
template <typename Visitor>
|
||||
Status VisitColumn(int32_t col_index, Visitor&& visit) const {
|
||||
for (size_t buf_index = 0; buf_index < values_buffers_.size(); ++buf_index) {
|
||||
const auto& values_buffer = values_buffers_[buf_index];
|
||||
const auto values = reinterpret_cast<const ValueDesc*>(values_buffer->data());
|
||||
const auto max_pos =
|
||||
static_cast<int32_t>(values_buffer->size() / sizeof(ValueDesc)) - 1;
|
||||
for (int32_t pos = col_index; pos < max_pos; pos += num_cols_) {
|
||||
auto start = values[pos].offset;
|
||||
auto stop = values[pos + 1].offset;
|
||||
auto quoted = values[pos + 1].quoted;
|
||||
ARROW_RETURN_NOT_OK(visit(parsed_ + start, stop - start, quoted));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
protected:
|
||||
ARROW_DISALLOW_COPY_AND_ASSIGN(BlockParser);
|
||||
|
||||
Status DoParse(const char* data, uint32_t size, bool is_final, uint32_t* out_size);
|
||||
template <typename SpecializedOptions>
|
||||
Status DoParseSpecialized(const char* data, uint32_t size, bool is_final,
|
||||
uint32_t* out_size);
|
||||
|
||||
template <typename SpecializedOptions, typename ValuesWriter, typename ParsedWriter>
|
||||
Status ParseChunk(ValuesWriter* values_writer, ParsedWriter* parsed_writer,
|
||||
const char* data, const char* data_end, bool is_final,
|
||||
int32_t rows_in_chunk, const char** out_data, bool* finished_parsing);
|
||||
|
||||
// Parse a single line from the data pointer
|
||||
template <typename SpecializedOptions, typename ValuesWriter, typename ParsedWriter>
|
||||
Status ParseLine(ValuesWriter* values_writer, ParsedWriter* parsed_writer,
|
||||
const char* data, const char* data_end, bool is_final,
|
||||
const char** out_data);
|
||||
|
||||
MemoryPool* pool_;
|
||||
const ParseOptions options_;
|
||||
// The number of rows parsed from the block
|
||||
int32_t num_rows_;
|
||||
// The number of columns (can be -1 at start)
|
||||
int32_t num_cols_;
|
||||
// The maximum number of rows to parse from this block
|
||||
int32_t max_num_rows_;
|
||||
|
||||
// Linear scratchpad for parsed values
|
||||
struct ValueDesc {
|
||||
uint32_t offset : 31;
|
||||
bool quoted : 1;
|
||||
};
|
||||
|
||||
// XXX should we ensure the parsed buffer is padded with 8 or 16 excess zero bytes?
|
||||
// It may help with null parsing...
|
||||
std::vector<std::shared_ptr<Buffer>> values_buffers_;
|
||||
std::shared_ptr<Buffer> parsed_buffer_;
|
||||
const uint8_t* parsed_;
|
||||
int32_t values_size_;
|
||||
int32_t parsed_size_;
|
||||
|
||||
class ResizableValuesWriter;
|
||||
class PresizedValuesWriter;
|
||||
class PresizedParsedWriter;
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_PARSER_H
|
@ -1,53 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_READER_H
|
||||
#define ARROW_CSV_READER_H
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "arrow/csv/options.h" // IWYU pragma: keep
|
||||
#include "arrow/status.h"
|
||||
#include "arrow/util/visibility.h"
|
||||
|
||||
namespace arrow {
|
||||
|
||||
class MemoryPool;
|
||||
class Table;
|
||||
|
||||
namespace io {
|
||||
class InputStream;
|
||||
} // namespace io
|
||||
|
||||
namespace csv {
|
||||
|
||||
class ARROW_EXPORT TableReader {
|
||||
public:
|
||||
virtual ~TableReader() = default;
|
||||
|
||||
virtual Status Read(std::shared_ptr<Table>* out) = 0;
|
||||
|
||||
// XXX pass optional schema?
|
||||
static Status Make(MemoryPool* pool, std::shared_ptr<io::InputStream> input,
|
||||
const ReadOptions&, const ParseOptions&, const ConvertOptions&,
|
||||
std::shared_ptr<TableReader>* out);
|
||||
};
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_READER_H
|
@ -1,71 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#ifndef ARROW_CSV_TEST_COMMON_H
|
||||
#define ARROW_CSV_TEST_COMMON_H
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "arrow/csv/parser.h"
|
||||
#include "arrow/testing/gtest_util.h"
|
||||
|
||||
namespace arrow {
|
||||
namespace csv {
|
||||
|
||||
std::string MakeCSVData(std::vector<std::string> lines) {
|
||||
std::string s;
|
||||
for (const auto& line : lines) {
|
||||
s += line;
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
// Make a BlockParser from a vector of lines representing a CSV file
|
||||
void MakeCSVParser(std::vector<std::string> lines, ParseOptions options,
|
||||
std::shared_ptr<BlockParser>* out) {
|
||||
auto csv = MakeCSVData(lines);
|
||||
auto parser = std::make_shared<BlockParser>(options);
|
||||
uint32_t out_size;
|
||||
ASSERT_OK(parser->Parse(csv.data(), static_cast<uint32_t>(csv.size()), &out_size));
|
||||
ASSERT_EQ(out_size, csv.size()) << "trailing CSV data not parsed";
|
||||
*out = parser;
|
||||
}
|
||||
|
||||
void MakeCSVParser(std::vector<std::string> lines, std::shared_ptr<BlockParser>* out) {
|
||||
MakeCSVParser(lines, ParseOptions::Defaults(), out);
|
||||
}
|
||||
|
||||
// Make a BlockParser from a vector of strings representing a single CSV column
|
||||
void MakeColumnParser(std::vector<std::string> items, std::shared_ptr<BlockParser>* out) {
|
||||
auto options = ParseOptions::Defaults();
|
||||
// Need this to test for null (empty) values
|
||||
options.ignore_empty_lines = false;
|
||||
std::vector<std::string> lines;
|
||||
for (const auto& item : items) {
|
||||
lines.push_back(item + '\n');
|
||||
}
|
||||
MakeCSVParser(lines, options, out);
|
||||
ASSERT_EQ((*out)->num_cols(), 1) << "Should have seen only 1 CSV column";
|
||||
ASSERT_EQ((*out)->num_rows(), items.size());
|
||||
}
|
||||
|
||||
} // namespace csv
|
||||
} // namespace arrow
|
||||
|
||||
#endif // ARROW_CSV_TEST_COMMON_H
|
@ -1,26 +0,0 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "arrow/dataset/dataset.h"
|
||||
#include "arrow/dataset/discovery.h"
|
||||
#include "arrow/dataset/file_base.h"
|
||||
#include "arrow/dataset/file_csv.h"
|
||||
#include "arrow/dataset/file_feather.h"
|
||||
#include "arrow/dataset/file_parquet.h"
|
||||
#include "arrow/dataset/scanner.h"
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user