Program Listing for File AnnoyBackend.hpp

Return to documentation for file (include\util\db\backends\AnnoyBackend.hpp)

#pragma once

#include "util/common/Result.hpp"
#include "util/db/DbTypes.hpp"
#include "util/db/nearest/Types.hpp"

#include <annoy/annoylib.h>
#include <annoy/kissrandom.h>

#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <map>
#include <memory>
#include <optional>
#include <span>
#include <sstream>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>

namespace PDJE_UTIL::db::backends {

namespace detail {

/// Encodes every byte of `input` as two lowercase hexadecimal digits.
///
/// @param input Raw characters to encode; may contain embedded NULs.
/// @return A string of exactly `2 * input.size()` characters drawn from
///         `0-9a-f`, high nibble first.
inline std::string
hex_encode(std::string_view input)
{
    static constexpr char kAlphabet[] = "0123456789abcdef";

    // Size the output once and fill it in place instead of appending.
    std::string out(input.size() * 2, '\0');
    std::size_t cursor = 0;
    for (const char raw : input) {
        // Go through unsigned char so bytes >= 0x80 never index negatively.
        const auto value = static_cast<unsigned char>(raw);
        out[cursor++]    = kAlphabet[value >> 4];
        out[cursor++]    = kAlphabet[value & 0x0FU];
    }
    return out;
}

/// Hex-encodes a span of bytes by viewing it as characters and delegating to
/// hex_encode().
///
/// @param input Bytes to encode.
/// @return Lowercase hex text, or an empty string when `input` is empty.
inline std::string
hex_encode_bytes(std::span<const std::byte> input)
{
    // An empty span may have a null data pointer, so never form a view over it.
    if (input.empty()) {
        return std::string {};
    }

    const auto *first = reinterpret_cast<const char *>(input.data());
    const std::string_view as_text(first, input.size_bytes());
    return hex_encode(as_text);
}

/// Decodes hexadecimal text (upper- or lowercase digits) into raw bytes.
///
/// @param input Hex text; two digits per byte, high nibble first.
/// @return The decoded bytes (empty for empty input), or `std::nullopt` when
///         the length is odd or any character is not a hex digit.
inline std::optional<std::vector<std::byte>>
hex_decode_bytes(std::string_view input)
{
    if (input.size() % 2 != 0) {
        return std::nullopt;
    }

    // Maps one hex digit to its value, or -1 for anything else.
    const auto nibble_of = [](char digit) -> int {
        const auto code = static_cast<unsigned char>(digit);
        if ('0' <= code && code <= '9') {
            return code - '0';
        }
        if ('a' <= code && code <= 'f') {
            return code - 'a' + 10;
        }
        if ('A' <= code && code <= 'F') {
            return code - 'A' + 10;
        }
        return -1;
    };

    std::vector<std::byte> decoded;
    decoded.reserve(input.size() / 2);

    // The length is even, so a non-empty view always has at least two digits.
    while (!input.empty()) {
        const int high = nibble_of(input[0]);
        const int low  = nibble_of(input[1]);
        if (high < 0 || low < 0) {
            return std::nullopt;
        }
        decoded.push_back(static_cast<std::byte>((high << 4) | low));
        input.remove_prefix(2);
    }
    return decoded;
}

/// Decodes hexadecimal text into a string holding the decoded bytes verbatim.
///
/// @param input Hex text accepted by hex_decode_bytes().
/// @return The decoded string (possibly empty), or `std::nullopt` when
///         hex_decode_bytes() rejects `input`.
inline std::optional<std::string>
hex_decode_text(std::string_view input)
{
    const auto decoded = hex_decode_bytes(input);
    if (!decoded) {
        return std::nullopt;
    }

    std::string text(decoded->size(), '\0');
    // Skip the copy for empty payloads, where data() may be null.
    if (!text.empty()) {
        std::memcpy(text.data(), decoded->data(), text.size());
    }
    return text;
}

/// Hex-encodes the in-memory byte representation of a float vector.
///
/// The bytes are taken as stored, so the text is only portable between
/// platforms that share `float` layout and byte order.
///
/// @param embedding Components to encode.
/// @return `2 * sizeof(float)` hex digits per component; empty when
///         `embedding` is empty.
inline std::string
encode_embedding(std::span<const float> embedding)
{
    if (embedding.empty()) {
        return std::string {};
    }
    // std::as_bytes yields the same byte view as casting data() by hand.
    const std::span<const std::byte> raw = std::as_bytes(embedding);
    return hex_encode_bytes(raw);
}

inline std::optional<nearest::Embedding>
decode_embedding(std::string_view input, std::size_t expected_dimension)
{
    auto bytes = hex_decode_bytes(input);
    if (!bytes.has_value()) {
        return std::nullopt;
    }
    if ((bytes->size() % sizeof(float)) != 0) {
        return std::nullopt;
    }

    nearest::Embedding embedding(bytes->size() / sizeof(float));
    if (!bytes->empty()) {
        std::memcpy(embedding.data(), bytes->data(), bytes->size());
    }
    if (embedding.size() != expected_dimension) {
        return std::nullopt;
    }
    return embedding;
}

/// Returns the location of the record manifest (`records.tsv`) inside
/// `root_path`.
inline std::filesystem::path
annoy_manifest_path(const std::filesystem::path &root_path)
{
    std::filesystem::path manifest = root_path;
    manifest /= "records.tsv";
    return manifest;
}

/// Returns the location of the saved Annoy index (`index.ann`) inside
/// `root_path`.
inline std::filesystem::path
annoy_index_path(const std::filesystem::path &root_path)
{
    std::filesystem::path index_file = root_path;
    index_file /= "index.ann";
    return index_file;
}

/// Releases an error string with `std::free`.
///
/// @param error Pointer to release; `nullptr` is accepted and ignored.
inline void
free_annoy_error(char *error)
{
    // std::free is specified as a no-op for nullptr, so no guard is needed.
    std::free(error);
}

/// Splits one line into its tab-separated fields.
///
/// Empty fields are preserved, including a trailing one after a final tab.
/// An empty line yields no fields at all.
///
/// @param line Text to split; must not include the line terminator.
/// @return The fields in order of appearance.
inline std::vector<std::string>
split_tsv_line(const std::string &line)
{
    std::vector<std::string> columns;
    if (line.empty()) {
        return columns;
    }

    std::size_t field_start = 0;
    for (;;) {
        const std::size_t tab = line.find('\t', field_start);
        if (tab == std::string::npos) {
            // Last field: everything after the final tab (possibly nothing).
            columns.push_back(line.substr(field_start));
            break;
        }
        columns.push_back(line.substr(field_start, tab - field_start));
        field_start = tab + 1;
    }
    return columns;
}

} // namespace detail

/// Settings consumed by AnnoyBackend::create(), destroy() and open().
struct AnnoyConfig {
    /// Directory holding the backend's files (`records.tsv`, `index.ann`).
    /// Must not be empty.
    std::filesystem::path root_path;
    /// Open-mode flags; open() reads `read_only`, `create_if_missing` and
    /// `truncate_if_exists`.
    OpenOptions           open_options {};
    /// Number of float components in every embedding and query vector.
    /// Must be greater than zero.
    std::size_t           dimension = 0;
    /// Tree count handed to Annoy's `build()` when the index is rebuilt.
    int                   trees     = 10;
    /// Prefault flag handed to Annoy's `save()` when the index is persisted.
    bool                  prefault  = false;
};

/// Nearest-neighbour backend built on Annoy.
///
/// Items are held in memory in `records_` and persisted by close() to a
/// hex-encoded TSV manifest (`records.tsv`) under `root_path`. The Annoy
/// index is derived data: it is rebuilt lazily from `records_` whenever the
/// record set has changed since the last build, and is written to `index.ann`
/// on close(). open() restores state from the manifest only.
class AnnoyBackend {
  public:
    using config_type = AnnoyConfig;
    using IndexType =
        Annoy::AnnoyIndex<int,
                          float,
                          Annoy::Angular,
                          Annoy::Kiss32Random,
                          Annoy::AnnoyIndexSingleThreadedBuildPolicy>;

    /// Creates the backend directory (and any missing parents).
    ///
    /// @return `invalid_argument` when `root_path` is empty or `dimension` is
    ///         zero, `io_error` when the directory cannot be created,
    ///         success otherwise.
    static common::Result<void>
    create(const config_type &cfg)
    {
        if (cfg.root_path.empty()) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "AnnoyConfig.root_path must not be empty." });
        }
        if (cfg.dimension == 0) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "AnnoyConfig.dimension must be greater than zero." });
        }

        std::error_code ec;
        std::filesystem::create_directories(cfg.root_path, ec);
        if (ec) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, ec.message() });
        }
        return common::Result<void>::success();
    }

    /// Recursively removes the backend directory and everything in it.
    ///
    /// @return `invalid_argument` when `root_path` is empty, `io_error` when
    ///         removal fails, success otherwise (including when the
    ///         directory did not exist).
    static common::Result<void>
    destroy(const config_type &cfg)
    {
        if (cfg.root_path.empty()) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "AnnoyConfig.root_path must not be empty." });
        }

        std::error_code ec;
        std::filesystem::remove_all(cfg.root_path, ec);
        if (ec) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, ec.message() });
        }
        return common::Result<void>::success();
    }

    /// Opens the backend and loads any existing manifest into memory.
    ///
    /// @return `invalid_argument` for an already-open backend, an empty
    ///         `root_path`, a zero `dimension`, or read-only combined with
    ///         create/truncate; `not_found` when the directory is missing and
    ///         `create_if_missing` is off; `io_error` for filesystem
    ///         failures; `backend_error` for a malformed manifest.
    common::Result<void>
    open(const config_type &cfg)
    {
        if (is_open_) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument, "Annoy backend is already open." });
        }
        if (cfg.root_path.empty()) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "AnnoyConfig.root_path must not be empty." });
        }
        if (cfg.dimension == 0) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "AnnoyConfig.dimension must be greater than zero." });
        }
        if (cfg.open_options.read_only &&
            (cfg.open_options.create_if_missing || cfg.open_options.truncate_if_exists)) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "Annoy read-only mode cannot create or truncate the index." });
        }

        config_ = cfg;
        // NOTE(review): with truncate_if_exists set and create_if_missing
        // unset, the directory is removed here and open() then fails with
        // not_found below. Confirm that this combination is intended.
        if (config_.open_options.truncate_if_exists) {
            auto destroyed = destroy(config_);
            if (!destroyed.ok()) {
                return destroyed;
            }
        }

        // Use the error_code overloads so filesystem failures are reported
        // through the Result instead of escaping as filesystem_error.
        std::error_code ec;
        const bool      exists = std::filesystem::exists(config_.root_path, ec);
        if (ec) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, ec.message() });
        }
        if (!exists) {
            if (config_.open_options.create_if_missing) {
                auto created = create(config_);
                if (!created.ok()) {
                    return created;
                }
            } else {
                return common::Result<void>::failure(
                    { common::StatusCode::not_found,
                      "Annoy backend directory does not exist." });
            }
        }

        records_.clear();
        id_to_key_.clear();
        index_.reset();
        next_item_id_ = 0;
        index_dirty_  = true;
        is_open_      = true;

        const auto manifest_path = detail::annoy_manifest_path(config_.root_path);
        const bool has_manifest  = std::filesystem::exists(manifest_path, ec);
        if (ec) {
            reset_runtime_state();
            return common::Result<void>::failure(
                { common::StatusCode::io_error, ec.message() });
        }
        if (has_manifest) {
            auto loaded = load_manifest(manifest_path);
            if (!loaded.ok()) {
                // Do not stay open with a partially loaded record set.
                reset_runtime_state();
                return loaded;
            }
        }

        return common::Result<void>::success();
    }

    /// Persists state (unless read-only) and releases all in-memory data.
    ///
    /// The backend is closed even when persisting fails; the persist result
    /// is what gets returned. Closing an already-closed backend succeeds.
    common::Result<void>
    close()
    {
        if (!is_open_) {
            return common::Result<void>::success();
        }

        common::Result<void> result = common::Result<void>::success();
        if (!config_.open_options.read_only) {
            result = persist();
        }

        reset_runtime_state();
        return result;
    }

    /// Reports whether an item with the given key is stored.
    ///
    /// @return `closed` when the backend is not open.
    common::Result<bool>
    contains(std::string_view key) const
    {
        if (auto status = require_open(); !status.ok()) {
            return common::Result<bool>::failure(status);
        }
        return common::Result<bool>::success(records_.contains(std::string(key)));
    }

    /// Returns a copy of the item stored under `key`.
    ///
    /// @return `closed` when the backend is not open, `not_found` when no
    ///         such item exists.
    common::Result<nearest::Item>
    get_item(std::string_view key) const
    {
        if (auto status = require_open(); !status.ok()) {
            return common::Result<nearest::Item>::failure(status);
        }

        const auto found = records_.find(std::string(key));
        if (found == records_.end()) {
            return common::Result<nearest::Item>::failure(
                { common::StatusCode::not_found, "Annoy item was not found." });
        }
        return common::Result<nearest::Item>::success(found->second);
    }

    /// Inserts `item`, or replaces the stored item that has the same id.
    ///
    /// @return `closed`/`unsupported` when the backend is not writable,
    ///         `invalid_argument` for an empty id or an embedding whose size
    ///         differs from the configured dimension.
    common::Result<void>
    upsert_item(const nearest::Item &item)
    {
        if (auto status = require_writable(); !status.ok()) {
            return common::Result<void>::failure(status);
        }
        if (item.id.empty()) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument, "Annoy item id must not be empty." });
        }
        if (item.embedding.size() != config_.dimension) {
            return common::Result<void>::failure(
                { common::StatusCode::invalid_argument,
                  "Annoy item embedding size must match the configured dimension." });
        }

        records_[item.id] = item;
        index_dirty_      = true;
        return common::Result<void>::success();
    }

    /// Removes the item stored under `key`; a missing key is not an error.
    ///
    /// @return `closed`/`unsupported` when the backend is not writable.
    common::Result<void>
    erase_item(std::string_view key)
    {
        if (auto status = require_writable(); !status.ok()) {
            return common::Result<void>::failure(status);
        }

        // Only invalidate the index when a record was actually removed, so a
        // no-op erase does not force a full rebuild on the next search.
        if (records_.erase(std::string(key)) > 0) {
            index_dirty_ = true;
        }
        return common::Result<void>::success();
    }

    /// Finds the stored items nearest to `query`.
    ///
    /// Rebuilds the index first when records changed since the last build.
    ///
    /// @param query   Query vector; its size must equal the configured
    ///                dimension.
    /// @param options `limit` and `search_k` are forwarded to Annoy's
    ///                `get_nns_by_vector`.
    /// @return Hits in the order Annoy reports them. Empty when `limit` is
    ///         zero or nothing is stored. `closed`, `invalid_argument` or
    ///         `backend_error` on failure.
    common::Result<std::vector<nearest::SearchHit>>
    search(std::span<const float> query, nearest::SearchOptions options) const
    {
        if (auto status = require_open(); !status.ok()) {
            return common::Result<std::vector<nearest::SearchHit>>::failure(status);
        }
        if (query.size() != config_.dimension) {
            return common::Result<std::vector<nearest::SearchHit>>::failure(
                { common::StatusCode::invalid_argument,
                  "Annoy query vector size must match the configured dimension." });
        }
        if (options.limit == 0 || records_.empty()) {
            return common::Result<std::vector<nearest::SearchHit>>::success({});
        }

        auto rebuilt = rebuild_index();
        if (!rebuilt.ok()) {
            return common::Result<std::vector<nearest::SearchHit>>::failure(rebuilt.status());
        }
        if (!index_) {
            return common::Result<std::vector<nearest::SearchHit>>::success({});
        }

        std::vector<int>   ids;
        std::vector<float> distances;
        index_->get_nns_by_vector(query.data(), options.limit, options.search_k, &ids, &distances);

        std::vector<nearest::SearchHit> hits;
        hits.reserve(ids.size());
        for (std::size_t i = 0; i < ids.size(); ++i) {
            // Map Annoy's integer id back to the record key, skipping ids
            // that cannot be resolved.
            const auto found_key = id_to_key_.find(ids[i]);
            if (found_key == id_to_key_.end()) {
                continue;
            }

            const auto found_item = records_.find(found_key->second);
            if (found_item == records_.end()) {
                continue;
            }

            hits.push_back(
                { found_item->second.id,
                  i < distances.size() ? distances[i] : 0.0F,
                  found_item->second.text_payload,
                  found_item->second.bytes_payload });
        }

        return common::Result<std::vector<nearest::SearchHit>>::success(std::move(hits));
    }

    /// Returns every stored key in ascending order.
    ///
    /// @return `closed` when the backend is not open.
    common::Result<std::vector<Key>>
    list_keys() const
    {
        if (auto status = require_open(); !status.ok()) {
            return common::Result<std::vector<Key>>::failure(status);
        }

        std::vector<Key> keys;
        keys.reserve(records_.size());
        for (const auto &[key, _] : records_) {
            keys.push_back(key);
        }
        std::sort(keys.begin(), keys.end());
        return common::Result<std::vector<Key>>::success(std::move(keys));
    }

  private:
    /// Returns a `closed` status unless the backend is open.
    common::Status
    require_open() const
    {
        if (!is_open_) {
            return { common::StatusCode::closed, "Annoy backend is not open." };
        }
        return {};
    }

    /// Returns an error status unless the backend is open and not read-only.
    common::Status
    require_writable() const
    {
        if (auto status = require_open(); !status.ok()) {
            return status;
        }
        if (config_.open_options.read_only) {
            return { common::StatusCode::unsupported, "Annoy backend is opened read-only." };
        }
        return {};
    }

    /// Rebuilds the Annoy index and the id-to-key map from `records_` when
    /// the index is marked dirty; does nothing otherwise.
    ///
    /// Declared const because search() is const; the state it touches is
    /// mutable. On failure the index stays marked dirty, so the next call
    /// starts again from scratch.
    common::Result<void>
    rebuild_index() const
    {
        if (!is_open_) {
            return common::Result<void>::failure(
                { common::StatusCode::closed, "Annoy backend is not open." });
        }
        if (!index_dirty_) {
            return common::Result<void>::success();
        }

        index_ = std::make_unique<IndexType>(static_cast<int>(config_.dimension));
        id_to_key_.clear();
        next_item_id_ = 0;

        for (const auto &[key, item] : records_) {
            if (item.embedding.size() != config_.dimension) {
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error,
                      "Annoy manifest contains an embedding with the wrong dimension." });
            }

            const int item_id = next_item_id_++;
            char     *error   = nullptr;
            if (!index_->add_item(item_id, item.embedding.data(), &error)) {
                // Copy the message before releasing the library's buffer.
                std::string message = error != nullptr ? error : "Annoy add_item failed.";
                detail::free_annoy_error(error);
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error, std::move(message) });
            }
            id_to_key_[item_id] = key;
        }

        if (!records_.empty()) {
            char *error = nullptr;
            if (!index_->build(config_.trees, -1, &error)) {
                std::string message = error != nullptr ? error : "Annoy build failed.";
                detail::free_annoy_error(error);
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error, std::move(message) });
            }
        }

        index_dirty_ = false;
        return common::Result<void>::success();
    }

    /// Reads the TSV manifest into `records_`.
    ///
    /// Each non-empty line must have six fields: hex id, hex embedding, text
    /// flag (`0`/`1`), hex text payload, bytes flag (`0`/`1`), hex bytes
    /// payload. A payload field is decoded only when its flag is `1`.
    ///
    /// @return `io_error` when the file cannot be opened, `backend_error`
    ///         for any malformed line. Records read before the failure remain
    ///         in `records_`; open() discards them.
    common::Result<void>
    load_manifest(const std::filesystem::path &manifest_path)
    {
        std::ifstream input(manifest_path);
        if (!input) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, "Failed to open Annoy manifest file." });
        }

        std::string line;
        while (std::getline(input, line)) {
            if (line.empty()) {
                continue;
            }

            const auto fields = detail::split_tsv_line(line);
            if (fields.size() != 6) {
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error, "Annoy manifest is malformed." });
            }

            auto id = detail::hex_decode_text(fields[0]);
            auto embedding = detail::decode_embedding(fields[1], config_.dimension);
            if (!id.has_value() || !embedding.has_value()) {
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error,
                      "Annoy manifest contains malformed item data." });
            }

            nearest::Item item;
            item.id        = std::move(*id);
            item.embedding = std::move(*embedding);

            if (fields[2] == "1") {
                auto text_payload = detail::hex_decode_text(fields[3]);
                if (!text_payload.has_value()) {
                    return common::Result<void>::failure(
                        { common::StatusCode::backend_error,
                          "Annoy manifest text payload is malformed." });
                }
                item.text_payload = std::move(*text_payload);
            } else if (fields[2] != "0") {
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error,
                      "Annoy manifest text payload flag is malformed." });
            }

            if (fields[4] == "1") {
                auto bytes_payload = detail::hex_decode_bytes(fields[5]);
                if (!bytes_payload.has_value()) {
                    return common::Result<void>::failure(
                        { common::StatusCode::backend_error,
                          "Annoy manifest bytes payload is malformed." });
                }
                item.bytes_payload = std::move(*bytes_payload);
            } else if (fields[4] != "0") {
                return common::Result<void>::failure(
                    { common::StatusCode::backend_error,
                      "Annoy manifest bytes payload flag is malformed." });
            }

            // A later line with the same id replaces the earlier one.
            records_[item.id] = std::move(item);
        }

        return common::Result<void>::success();
    }

    /// Writes the manifest and the Annoy index file under `root_path`.
    ///
    /// The manifest is rewritten from scratch with one line per record in
    /// ascending key order. When there are no records the index file is
    /// removed instead of saved.
    ///
    /// @return `backend_error` when rebuilding or saving the index fails,
    ///         `io_error` when the directory or manifest cannot be written.
    common::Result<void>
    persist()
    {
        auto rebuilt = rebuild_index();
        if (!rebuilt.ok()) {
            return rebuilt;
        }

        std::error_code ec;
        std::filesystem::create_directories(config_.root_path, ec);
        if (ec) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, ec.message() });
        }

        const auto manifest_path = detail::annoy_manifest_path(config_.root_path);
        std::ofstream output(manifest_path, std::ios::trunc);
        if (!output) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, "Failed to write Annoy manifest file." });
        }

        auto keys_result = list_keys();
        if (!keys_result.ok()) {
            return common::Result<void>::failure(keys_result.status());
        }

        for (const auto &key : keys_result.value()) {
            const auto &item = records_.at(key);
            output << detail::hex_encode(item.id) << '\t'
                   << detail::encode_embedding(item.embedding) << '\t'
                   << (item.text_payload.has_value() ? "1" : "0") << '\t'
                   << (item.text_payload.has_value() ? detail::hex_encode(*item.text_payload)
                                                     : std::string {})
                   << '\t'
                   << (item.bytes_payload.has_value() ? "1" : "0") << '\t'
                   << (item.bytes_payload.has_value()
                           ? detail::hex_encode_bytes(*item.bytes_payload)
                           : std::string {})
                   << '\n';
        }
        output.close();
        // A failed write or flush leaves the stream in a failed state. Report
        // it rather than claiming the records were saved.
        if (!output) {
            return common::Result<void>::failure(
                { common::StatusCode::io_error, "Failed to write Annoy manifest file." });
        }

        const auto index_path = detail::annoy_index_path(config_.root_path);
        if (records_.empty()) {
            // Best effort: a leftover index file is not read back by open().
            std::filesystem::remove(index_path, ec);
            return common::Result<void>::success();
        }

        char *error = nullptr;
        if (!index_->save(index_path.string().c_str(), config_.prefault, &error)) {
            std::string message = error != nullptr ? error : "Annoy save failed.";
            detail::free_annoy_error(error);
            return common::Result<void>::failure(
                { common::StatusCode::backend_error, std::move(message) });
        }

        return common::Result<void>::success();
    }

    /// Drops all in-memory state and marks the backend closed.
    void
    reset_runtime_state()
    {
        index_.reset();
        id_to_key_.clear();
        records_.clear();
        next_item_id_ = 0;
        index_dirty_  = true;
        is_open_      = false;
    }

    /// Configuration captured by the most recent open() call.
    config_type                                    config_ {};
    /// True between a successful open() and the next close().
    bool                                           is_open_ = false;
    /// True when `records_` changed since the index was last built.
    mutable bool                                   index_dirty_ = true;
    /// Next integer id to hand to Annoy during a rebuild.
    mutable int                                    next_item_id_ = 0;
    /// Authoritative item store, keyed by item id.
    std::unordered_map<Key, nearest::Item>         records_ {};
    /// Maps Annoy's integer ids back to record keys for the current index.
    mutable std::map<int, Key>                     id_to_key_ {};
    /// Lazily built Annoy index; null until the first rebuild.
    mutable std::unique_ptr<IndexType>             index_ {};
};

} // namespace PDJE_UTIL::db::backends