.. _program_listing_file_include_util_db_backends_AnnoyBackend.hpp:

Program Listing for File AnnoyBackend.hpp
=========================================

|exhale_lsh| :ref:`Return to documentation for file ` (``include\util\db\backends\AnnoyBackend.hpp``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   // NOTE(review): in this listing every angle-bracket construct is absent
   // (bare `#include` directives, `std::span input`, `reinterpret_cast(...)`,
   // `common::Result create`, ...). Consult the real header for the exact
   // include targets and template arguments; the comments below describe only
   // what the visible code shows.
   #pragma once
   #include "util/common/Result.hpp"
   #include "util/db/DbTypes.hpp"
   #include "util/db/nearest/Types.hpp"
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include
   #include

   namespace PDJE_UTIL::db::backends {
   namespace detail {

   // Encodes each byte of `input` as two lowercase hexadecimal characters,
   // high nibble first. The result is exactly twice as long as the input.
   inline std::string hex_encode(std::string_view input) {
       static constexpr char hex_digits[] = "0123456789abcdef";
       std::string encoded;
       encoded.reserve(input.size() * 2);
       // Iterating as unsigned char keeps the shift/mask below well defined for
       // bytes >= 0x80.
       for (unsigned char ch : input) {
           encoded.push_back(hex_digits[(ch >> 4) & 0x0F]);
           encoded.push_back(hex_digits[ch & 0x0F]);
       }
       return encoded;
   }

   // Hex-encodes the raw bytes viewed by `input` by forwarding to hex_encode().
   // An empty span yields an empty string without touching input.data().
   inline std::string hex_encode_bytes(std::span input) {
       if (input.empty()) {
           return {};
       }
       return hex_encode(std::string_view(reinterpret_cast(input.data()), input.size_bytes()));
   }

   // Decodes a hexadecimal string into a byte vector.
   // Returns std::nullopt when the input has odd length or contains a character
   // that is not a hex digit. Both upper- and lower-case digits are accepted.
   // An empty input decodes to an empty vector.
   inline std::optional> hex_decode_bytes(std::string_view input) {
       if ((input.size() % 2) != 0) {
           return std::nullopt;
       }
       // Maps one hex character to its value 0..15, or -1 if it is not a hex digit.
       auto decode_nibble = [](unsigned char ch) -> int {
           if (ch >= '0' && ch <= '9') {
               return ch - '0';
           }
           if (ch >= 'a' && ch <= 'f') {
               return 10 + ch - 'a';
           }
           if (ch >= 'A' && ch <= 'F') {
               return 10 + ch - 'A';
           }
           return -1;
       };
       std::vector bytes;
       bytes.reserve(input.size() / 2);
       // The even-length check above guarantees input[i + 1] is in range.
       for (std::size_t i = 0; i < input.size(); i += 2) {
           const int hi = decode_nibble(static_cast(input[i]));
           const int lo = decode_nibble(static_cast(input[i + 1]));
           if (hi < 0 || lo < 0) {
               return std::nullopt;
           }
           bytes.push_back(static_cast((hi << 4) | lo));
       }
       return bytes;
   }

   // Decodes a hexadecimal string into a std::string holding the decoded bytes.
   // Returns std::nullopt under the same conditions as hex_decode_bytes().
   inline std::optional hex_decode_text(std::string_view input) {
       auto bytes = hex_decode_bytes(input);
       if (!bytes.has_value()) {
           return std::nullopt;
       }
       // Handle the empty case explicitly instead of building a string from the
       // data() pointer of an empty vector.
       if (bytes->empty()) {
           return std::string {};
       }
       return std::string(reinterpret_cast(bytes->data()), bytes->size());
   }

   // Serializes an embedding as the hex encoding of its in-memory bytes
   // (embedding.size() * sizeof(float) bytes). No byte-order conversion is
   // performed here. An empty embedding encodes to an empty string.
   inline std::string encode_embedding(std::span embedding) {
       if (embedding.empty()) {
           return {};
       }
       const auto *bytes = reinterpret_cast(embedding.data());
       return hex_encode_bytes(std::span( bytes, embedding.size() * sizeof(float)));
   }

   // Inverse of encode_embedding(): hex-decodes `input` and copies the bytes
   // verbatim into a nearest::Embedding.
   // Returns std::nullopt when the hex is invalid, when the byte count is not a
   // multiple of sizeof(float), or when the resulting element count differs from
   // `expected_dimension`.
   inline std::optional decode_embedding(std::string_view input, std::size_t expected_dimension) {
       auto bytes = hex_decode_bytes(input);
       if (!bytes.has_value()) {
           return std::nullopt;
       }
       if ((bytes->size() % sizeof(float)) != 0) {
           return std::nullopt;
       }
       nearest::Embedding embedding(bytes->size() / sizeof(float));
       // memcpy is skipped for zero bytes so it is never called on empty buffers.
       if (!bytes->empty()) {
           std::memcpy(embedding.data(), bytes->data(), bytes->size());
       }
       if (embedding.size() != expected_dimension) {
           return std::nullopt;
       }
       return embedding;
   }

   // Path of the record manifest ("records.tsv") inside the backend directory.
   inline std::filesystem::path annoy_manifest_path(const std::filesystem::path &root_path) {
       return root_path / "records.tsv";
   }

   // Path of the saved Annoy index ("index.ann") inside the backend directory.
   inline std::filesystem::path annoy_index_path(const std::filesystem::path &root_path) {
       return root_path / "index.ann";
   }

   // Releases an error string handed back through an Annoy `char **error`
   // out-parameter using std::free(). Null is tolerated.
   // NOTE(review): assumes Annoy allocates these messages with malloc - confirm
   // against annoylib.
   inline void free_annoy_error(char *error) {
       if (error != nullptr) {
           std::free(error);
       }
   }

   // Splits one manifest line on tab characters.
   // std::getline() does not report a final empty field after a trailing tab,
   // so that field is appended explicitly. An empty line yields no fields.
   inline std::vector split_tsv_line(const std::string &line) {
       std::vector fields;
       std::istringstream stream(line);
       std::string field;
       while (std::getline(stream, field, '\t')) {
           fields.push_back(field);
       }
       if (!line.empty() && line.back() == '\t') {
           fields.emplace_back();
       }
       return fields;
   }

   } // namespace detail

   // Configuration consumed by AnnoyBackend.
   struct AnnoyConfig {
       // Directory that holds the manifest ("records.tsv") and index ("index.ann").
       std::filesystem::path root_path;
       // Open-mode flags; this file reads read_only, create_if_missing and
       // truncate_if_exists.
       OpenOptions open_options {};
       // Number of floats per embedding; must be non-zero for create()/open().
       std::size_t dimension = 0;
       // Passed as the first argument of the Annoy index build() call.
       int trees = 10;
       // Passed as the second argument of the Annoy index save() call.
       bool prefault = false;
   };

   // Nearest-neighbour backend that keeps all records in memory, persists them
   // to a tab-separated manifest, and builds an Annoy index from the in-memory
   // records whenever they have changed since the last build.
   class AnnoyBackend {
     public:
       using config_type = AnnoyConfig;
       // NOTE(review): the AnnoyIndex template arguments are absent from this listing.
       using IndexType = Annoy::AnnoyIndex;

       // Validates `cfg` and creates the backend directory (including parents).
       // Fails with invalid_argument for an empty root_path or a zero dimension,
       // and with io_error when the directory cannot be created.
       static common::Result create(const config_type &cfg) {
           if (cfg.root_path.empty()) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "AnnoyConfig.root_path must not be empty." });
           }
           if (cfg.dimension == 0) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "AnnoyConfig.dimension must be greater than zero." });
           }
           std::error_code ec;
           std::filesystem::create_directories(cfg.root_path, ec);
           if (ec) {
               return common::Result::failure(
                   { common::StatusCode::io_error, ec.message() });
           }
           return common::Result::success();
       }

       // Recursively removes the backend directory.
       // Fails with invalid_argument for an empty root_path and with io_error when
       // removal reports an error. Unlike create(), the dimension is not checked.
       static common::Result destroy(const config_type &cfg) {
           if (cfg.root_path.empty()) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "AnnoyConfig.root_path must not be empty." });
           }
           std::error_code ec;
           std::filesystem::remove_all(cfg.root_path, ec);
           if (ec) {
               return common::Result::failure(
                   { common::StatusCode::io_error, ec.message() });
           }
           return common::Result::success();
       }

       // Opens the backend described by `cfg` and loads the manifest if one exists.
       //
       // Failure cases visible here:
       //  - invalid_argument: already open, empty root_path, zero dimension, or
       //    read_only combined with create_if_missing / truncate_if_exists;
       //  - not_found: directory missing and create_if_missing is false;
       //  - any failure propagated from destroy(), create() or load_manifest().
       //
       // The saved "index.ann" file is not read here; the index starts dirty and
       // is rebuilt from the loaded records on demand.
       common::Result open(const config_type &cfg) {
           if (is_open_) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "Annoy backend is already open." });
           }
           if (cfg.root_path.empty()) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "AnnoyConfig.root_path must not be empty." });
           }
           if (cfg.dimension == 0) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "AnnoyConfig.dimension must be greater than zero." });
           }
           if (cfg.open_options.read_only && (cfg.open_options.create_if_missing || cfg.open_options.truncate_if_exists)) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "Annoy read-only mode cannot create or truncate the index." });
           }
           // NOTE(review): config_ is overwritten before the filesystem steps below
           // can fail, so a failed open() leaves the new config stored while
           // is_open_ stays false.
           config_ = cfg;
           // Truncation is implemented by deleting the whole directory.
           if (config_.open_options.truncate_if_exists) {
               auto destroyed = destroy(config_);
               if (!destroyed.ok()) {
                   return destroyed;
               }
           }
           const bool exists = std::filesystem::exists(config_.root_path);
           if (!exists) {
               if (config_.open_options.create_if_missing) {
                   auto created = create(config_);
                   if (!created.ok()) {
                       return created;
                   }
               } else {
                   return common::Result::failure(
                       { common::StatusCode::not_found,
                         "Annoy backend directory does not exist." });
               }
           }
           // Start from a clean in-memory state with the index marked stale.
           records_.clear();
           id_to_key_.clear();
           index_.reset();
           next_item_id_ = 0;
           index_dirty_ = true;
           is_open_ = true;
           const auto manifest_path = detail::annoy_manifest_path(config_.root_path);
           if (std::filesystem::exists(manifest_path)) {
               auto loaded = load_manifest(manifest_path);
               if (!loaded.ok()) {
                   // Drop partially loaded records and return to the closed state.
                   reset_runtime_state();
                   return loaded;
               }
           }
           return common::Result::success();
       }

       // Closes the backend. Succeeds immediately when it is not open.
       // When not read-only, persist() runs first and its result is returned;
       // the in-memory state is reset even if persisting failed.
       common::Result close() {
           if (!is_open_) {
               return common::Result::success();
           }
           common::Result result = common::Result::success();
           if (!config_.open_options.read_only) {
               result = persist();
           }
           reset_runtime_state();
           return result;
       }

       // Reports whether a record with the given key is held in memory.
       // Fails with the require_open() status when the backend is closed.
       common::Result contains(std::string_view key) const {
           if (auto status = require_open(); !status.ok()) {
               return common::Result::failure(status);
           }
           return common::Result::success(records_.contains(std::string(key)));
       }

       // Returns a copy of the record stored under `key`.
       // Fails with not_found when no such record exists, or with the
       // require_open() status when the backend is closed.
       common::Result get_item(std::string_view key) const {
           if (auto status = require_open(); !status.ok()) {
               return common::Result::failure(status);
           }
           const auto found = records_.find(std::string(key));
           if (found == records_.end()) {
               return common::Result::failure(
                   { common::StatusCode::not_found,
                     "Annoy item was not found." });
           }
           return common::Result::success(found->second);
       }

       // Inserts `item`, or replaces the record with the same id, and marks the
       // index stale. Fails with invalid_argument for an empty id or an embedding
       // whose size differs from the configured dimension, and with the
       // require_writable() status when closed or read-only.
       common::Result upsert_item(const nearest::Item &item) {
           if (auto status = require_writable(); !status.ok()) {
               return common::Result::failure(status);
           }
           if (item.id.empty()) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "Annoy item id must not be empty." });
           }
           if (item.embedding.size() != config_.dimension) {
               return common::Result::failure(
                   { common::StatusCode::invalid_argument,
                     "Annoy item embedding size must match the configured dimension." });
           }
           records_[item.id] = item;
           index_dirty_ = true;
           return common::Result::success();
       }

       // Removes the record stored under `key`; a missing key is not an error.
       // NOTE(review): the index is marked dirty even when nothing was erased,
       // which forces a rebuild on the next search()/persist().
       common::Result erase_item(std::string_view key) {
           if (auto status = require_writable(); !status.ok()) {
               return common::Result::failure(status);
           }
           records_.erase(std::string(key));
           index_dirty_ = true;
           return common::Result::success();
       }

       // Runs a nearest-neighbour query against the (lazily rebuilt) index.
       //
       // `query` must have exactly config_.dimension elements. `options.limit`
       // and `options.search_k` are forwarded to get_nns_by_vector(). A limit of
       // zero, or an empty record set, returns an empty hit list without touching
       // the index. Each hit carries the record id, the reported distance, and the
       // record's text and bytes payloads.
       //
       // Although const, this may rebuild the index through the mutable members.
       common::Result> search(std::span query, nearest::SearchOptions options) const {
           if (auto status = require_open(); !status.ok()) {
               return common::Result>::failure(status);
           }
           if (query.size() != config_.dimension) {
               return common::Result>::failure(
                   { common::StatusCode::invalid_argument,
                     "Annoy query vector size must match the configured dimension." });
           }
           if (options.limit == 0 || records_.empty()) {
               return common::Result>::success({});
           }
           auto rebuilt = rebuild_index();
           if (!rebuilt.ok()) {
               return common::Result>::failure(rebuilt.status());
           }
           if (!index_) {
               return common::Result>::success({});
           }
           std::vector ids;
           std::vector distances;
           index_->get_nns_by_vector(query.data(), options.limit, options.search_k, &ids, &distances);
           std::vector hits;
           hits.reserve(ids.size());
           for (std::size_t i = 0; i < ids.size(); ++i) {
               // Translate the Annoy item id back to the record key; ids that
               // cannot be resolved are skipped rather than reported as errors.
               const auto found_key = id_to_key_.find(ids[i]);
               if (found_key == id_to_key_.end()) {
                   continue;
               }
               const auto found_item = records_.find(found_key->second);
               if (found_item == records_.end()) {
                   continue;
               }
               // Fall back to 0.0F if fewer distances than ids were returned.
               hits.push_back(
                   { found_item->second.id,
                     i < distances.size() ? distances[i] : 0.0F,
                     found_item->second.text_payload,
                     found_item->second.bytes_payload });
           }
           return common::Result>::success(std::move(hits));
       }

       // Returns all record keys sorted in ascending order.
       // Fails with the require_open() status when the backend is closed.
       common::Result> list_keys() const {
           if (auto status = require_open(); !status.ok()) {
               return common::Result>::failure(status);
           }
           std::vector keys;
           keys.reserve(records_.size());
           for (const auto &[key, _] : records_) {
               keys.push_back(key);
           }
           // records_ is an unordered_map, so sort to get a deterministic order.
           std::sort(keys.begin(), keys.end());
           return common::Result>::success(std::move(keys));
       }

     private:
       // Returns a `closed` status when the backend is not open, otherwise a
       // default-constructed status.
       common::Status require_open() const {
           if (!is_open_) {
               return { common::StatusCode::closed,
                        "Annoy backend is not open." };
           }
           return {};
       }

       // Like require_open(), and additionally returns `unsupported` when the
       // backend was opened read-only.
       common::Status require_writable() const {
           if (auto status = require_open(); !status.ok()) {
               return status;
           }
           if (config_.open_options.read_only) {
               return { common::StatusCode::unsupported,
                        "Annoy backend is opened read-only." };
           }
           return {};
       }

       // Rebuilds the Annoy index from records_ if it is marked dirty.
       //
       // A fresh index is created, every record is added under a sequential
       // integer id (in records_ iteration order), id_to_key_ is refilled, and
       // build() is called unless there are no records. Annoy failures are
       // returned as backend_error with Annoy's message when one is provided.
       // On any failure index_dirty_ stays true, so the next call retries.
       common::Result rebuild_index() const {
           if (!is_open_) {
               return common::Result::failure(
                   { common::StatusCode::closed,
                     "Annoy backend is not open." });
           }
           if (!index_dirty_) {
               return common::Result::success();
           }
           index_ = std::make_unique(static_cast(config_.dimension));
           id_to_key_.clear();
           next_item_id_ = 0;
           for (const auto &[key, item] : records_) {
               if (item.embedding.size() != config_.dimension) {
                   return common::Result::failure(
                       { common::StatusCode::backend_error,
                         "Annoy manifest contains an embedding with the wrong dimension." });
               }
               const int item_id = next_item_id_++;
               char *error = nullptr;
               if (!index_->add_item(item_id, item.embedding.data(), &error)) {
                   // Copy the message before releasing Annoy's buffer.
                   std::string message = error != nullptr ? error : "Annoy add_item failed.";
                   detail::free_annoy_error(error);
                   return common::Result::failure(
                       { common::StatusCode::backend_error, std::move(message) });
               }
               id_to_key_[item_id] = key;
           }
           if (!records_.empty()) {
               char *error = nullptr;
               // NOTE(review): the meaning of the -1 argument is not shown in this
               // file - confirm against annoylib's build() signature.
               if (!index_->build(config_.trees, -1, &error)) {
                   std::string message = error != nullptr ? error : "Annoy build failed.";
                   detail::free_annoy_error(error);
                   return common::Result::failure(
                       { common::StatusCode::backend_error, std::move(message) });
               }
           }
           index_dirty_ = false;
           return common::Result::success();
       }

       // Reads the manifest into records_.
       //
       // Each non-empty line must contain exactly six tab-separated fields:
       //   0: hex-encoded item id
       //   1: hex-encoded embedding (must decode to config_.dimension floats)
       //   2: "1"/"0" flag - text payload present
       //   3: hex-encoded text payload (decoded only when field 2 is "1")
       //   4: "1"/"0" flag - bytes payload present
       //   5: hex-encoded bytes payload (decoded only when field 4 is "1")
       // Fails with io_error if the file cannot be opened and with backend_error
       // on any malformed line. A later line with the same id replaces the
       // earlier record. Records loaded before a failure are left in records_.
       common::Result load_manifest(const std::filesystem::path &manifest_path) {
           std::ifstream input(manifest_path);
           if (!input) {
               return common::Result::failure(
                   { common::StatusCode::io_error,
                     "Failed to open Annoy manifest file." });
           }
           std::string line;
           while (std::getline(input, line)) {
               if (line.empty()) {
                   continue;
               }
               const auto fields = detail::split_tsv_line(line);
               if (fields.size() != 6) {
                   return common::Result::failure(
                       { common::StatusCode::backend_error,
                         "Annoy manifest is malformed." });
               }
               auto id = detail::hex_decode_text(fields[0]);
               auto embedding = detail::decode_embedding(fields[1], config_.dimension);
               if (!id.has_value() || !embedding.has_value()) {
                   return common::Result::failure(
                       { common::StatusCode::backend_error,
                         "Annoy manifest contains malformed item data." });
               }
               nearest::Item item;
               item.id = std::move(*id);
               item.embedding = std::move(*embedding);
               // The flag distinguishes "no payload" from "empty payload".
               if (fields[2] == "1") {
                   auto text_payload = detail::hex_decode_text(fields[3]);
                   if (!text_payload.has_value()) {
                       return common::Result::failure(
                           { common::StatusCode::backend_error,
                             "Annoy manifest text payload is malformed." });
                   }
                   item.text_payload = std::move(*text_payload);
               } else if (fields[2] != "0") {
                   return common::Result::failure(
                       { common::StatusCode::backend_error,
                         "Annoy manifest text payload flag is malformed." });
               }
               if (fields[4] == "1") {
                   auto bytes_payload = detail::hex_decode_bytes(fields[5]);
                   if (!bytes_payload.has_value()) {
                       return common::Result::failure(
                           { common::StatusCode::backend_error,
                             "Annoy manifest bytes payload is malformed." });
                   }
                   item.bytes_payload = std::move(*bytes_payload);
               } else if (fields[4] != "0") {
                   return common::Result::failure(
                       { common::StatusCode::backend_error,
                         "Annoy manifest bytes payload flag is malformed." });
               }
               records_[item.id] = std::move(item);
           }
           return common::Result::success();
       }

       // Writes the current state to disk: rebuilds the index if stale, rewrites
       // the manifest (one line per record, in sorted key order, using the field
       // layout read by load_manifest()), then saves the Annoy index. When there
       // are no records the index file is removed instead of saved.
       common::Result persist() {
           auto rebuilt = rebuild_index();
           if (!rebuilt.ok()) {
               return rebuilt;
           }
           std::error_code ec;
           std::filesystem::create_directories(config_.root_path, ec);
           if (ec) {
               return common::Result::failure(
                   { common::StatusCode::io_error, ec.message() });
           }
           const auto manifest_path = detail::annoy_manifest_path(config_.root_path);
           // The manifest is truncated in place; no temporary file is used.
           std::ofstream output(manifest_path, std::ios::trunc);
           if (!output) {
               return common::Result::failure(
                   { common::StatusCode::io_error,
                     "Failed to write Annoy manifest file." });
           }
           auto keys_result = list_keys();
           if (!keys_result.ok()) {
               return common::Result::failure(keys_result.status());
           }
           for (const auto &key : keys_result.value()) {
               const auto &item = records_.at(key);
               output << detail::hex_encode(item.id) << '\t'
                      << detail::encode_embedding(item.embedding) << '\t'
                      << (item.text_payload.has_value() ? "1" : "0") << '\t'
                      << (item.text_payload.has_value() ? detail::hex_encode(*item.text_payload) : std::string {}) << '\t'
                      << (item.bytes_payload.has_value() ? "1" : "0") << '\t'
                      << (item.bytes_payload.has_value() ? detail::hex_encode_bytes(*item.bytes_payload) : std::string {}) << '\n';
           }
           // NOTE(review): the stream state is not checked after the writes or
           // after close(), so a short write (e.g. disk full) is not reported.
           output.close();
           const auto index_path = detail::annoy_index_path(config_.root_path);
           if (records_.empty()) {
               // NOTE(review): the error code set by remove() is not inspected.
               std::filesystem::remove(index_path, ec);
               return common::Result::success();
           }
           char *error = nullptr;
           if (!index_->save(index_path.string().c_str(), config_.prefault, &error)) {
               std::string message = error != nullptr ? error : "Annoy save failed.";
               detail::free_annoy_error(error);
               return common::Result::failure(
                   { common::StatusCode::backend_error, std::move(message) });
           }
           return common::Result::success();
       }

       // Drops the index and all records and returns to the closed state.
       // config_ is left untouched.
       void reset_runtime_state() {
           index_.reset();
           id_to_key_.clear();
           records_.clear();
           next_item_id_ = 0;
           index_dirty_ = true;
           is_open_ = false;
       }

       // Configuration captured by the most recent open() call.
       config_type config_ {};
       bool is_open_ = false;
       // The members below marked `mutable` are modified by the const
       // rebuild_index(), which the const search() calls.
       // True when records_ changed since the index was last built.
       mutable bool index_dirty_ = true;
       // Next sequential Annoy item id handed out during a rebuild.
       mutable int next_item_id_ = 0;
       // In-memory records keyed by item id.
       std::unordered_map records_ {};
       // Maps Annoy item ids back to record keys; refilled on every rebuild.
       mutable std::map id_to_key_ {};
       // Current Annoy index; null until the first rebuild after open().
       mutable std::unique_ptr index_ {};
   };

   } // namespace PDJE_UTIL::db::backends