feat(backend): expose generate method through the lib
mfuntowicz committed Oct 26, 2024
1 parent 6416e90 commit ed5dfbf
Showing 3 changed files with 39 additions and 6 deletions.
3 changes: 2 additions & 1 deletion backends/llamacpp/csrc/backend.hpp
@@ -17,7 +17,8 @@
 
 namespace huggingface::tgi::backends::llamacpp {
     enum TgiLlamaCppBackendError : uint8_t {
-        MODEL_FILE_DOESNT_EXIST = 1
+        MODEL_FILE_DOESNT_EXIST = 1 << 0,
+        OUTPUT_BUFFER_NOT_BIG_ENOUGH = 1 << 7
     };
 
     class TgiLlamaCppBackend {
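The two discriminants are distinct bit flags (1 << 0 and 1 << 7), so several error conditions could in principle be OR-ed into a single u8 mask. On the Rust side these codes only surface indirectly, as thrown C++ exceptions, but a mirrored constant set can be convenient in tests. A minimal sketch, assuming nothing beyond the two values above (the Rust enum name and derives are hypothetical):

// Hypothetical Rust mirror of the C++ TgiLlamaCppBackendError flags;
// only the discriminants (1 << 0, 1 << 7) come from backend.hpp.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackendErrorCode {
    ModelFileDoesntExist = 1 << 0,
    OutputBufferNotBigEnough = 1 << 7,
}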
29 changes: 27 additions & 2 deletions backends/llamacpp/csrc/ffi.hpp
@@ -5,8 +5,10 @@
 #ifndef TGI_LLAMA_CPP_BACKEND_FFI_HPP
 #define TGI_LLAMA_CPP_BACKEND_FFI_HPP
 
+#include <cstdint>
 #include <exception>
 #include <filesystem>
+#include <span>
 #include <string_view>
 
 #include <spdlog/spdlog.h>
@@ -22,8 +24,10 @@ namespace huggingface::tgi::backends::llamacpp::impl {
 
 namespace huggingface::tgi::backends::llamacpp::impl {
 
-    class LlamaCppBackendException : std::exception {
+    struct LlamaCppBackendImplException : std::exception {
+        TgiLlamaCppBackendError error;
 
+        explicit LlamaCppBackendImplException(TgiLlamaCppBackendError error) : error(error) {}
     };
 
     class LlamaCppBackendImpl {
@@ -32,6 +36,27 @@ namespace huggingface::tgi::backends::llamacpp::impl {
 
     public:
         LlamaCppBackendImpl(llama_model *model, llama_context *context) : _inner(model, context) {}
+
+        size_t Generate(
+                const rust::Slice<const int32_t> tokens,
+                rust::Slice<int32_t> out,
+                uint32_t topK,
+                uint32_t maxNewTokens
+        ) {
+            if (out.size() < maxNewTokens) {
+                throw LlamaCppBackendImplException(TgiLlamaCppBackendError::OUTPUT_BUFFER_NOT_BIG_ENOUGH);
+            }
+
+            std::span<const int32_t> tokens_(tokens.data(), tokens.size());
+            std::span<int32_t> out_(out.data(), out.size());
+
+            const auto nGenerated = _inner.Generate(tokens_, out_, topK, 1.0, 0.0f, 0.0f, maxNewTokens);
+            if (nGenerated.has_value()) {
+                return *nGenerated;
+            } else {
+                throw LlamaCppBackendImplException(nGenerated.error());
+            }
+        }
     };
 
     std::unique_ptr<LlamaCppBackendImpl> CreateLlamaCppBackendImpl(rust::Str modelPath) {
@@ -40,7 +65,7 @@
             auto [model, context] = *maybe;
             return std::make_unique<LlamaCppBackendImpl>(model, context);
         } else {
-            throw LlamaCppBackendException();
+            throw LlamaCppBackendImplException(maybe.error());
         }
     }
 }
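Both throw sites rely on the same mechanism: because the bridged functions are declared with Result<...> in lib.rs, cxx catches any C++ exception derived from std::exception and surfaces it in Rust as Err(cxx::Exception). A caller-side sketch; the model path and message handling below are illustrative, not from the commit:

// Hypothetical error handling on the Rust side; cxx converts the thrown
// LlamaCppBackendImplException into Err(cxx::Exception).
match ffi::create_llamacpp_backend("/models/does-not-exist.gguf") {
    Ok(_backend) => println!("backend ready"),
    Err(e) => eprintln!("backend creation failed: {}", e.what()),
}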
13 changes: 10 additions & 3 deletions backends/llamacpp/src/lib.rs
@@ -9,8 +9,15 @@ mod ffi {
         type LlamaCppBackendImpl;
 
         #[rust_name = "create_llamacpp_backend"]
-        fn CreateLlamaCppBackendImpl(
-            modelPath: &str,
-        ) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+        fn CreateLlamaCppBackendImpl(modelPath: &str) -> Result<UniquePtr<LlamaCppBackendImpl>>;
+
+        #[rust_name = "generate"]
+        fn Generate(
+            self: Pin<&mut LlamaCppBackendImpl>,
+            tokens: &[i32],
+            generated: &mut [i32],
+            top_k: u32,
+            max_new_tokens: u32,
+        ) -> Result<usize>;
     }
 }
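With the bridge in place, Rust code in this crate can drive generation end to end. A minimal usage sketch, assuming it lives alongside the ffi module above; the prompt ids, buffer size, and top_k value are illustrative, not from the commit:

fn run(model_path: &str, prompt: &[i32]) -> Result<Vec<i32>, cxx::Exception> {
    // Fails with Err(cxx::Exception) if the model file does not exist.
    let mut backend = ffi::create_llamacpp_backend(model_path)?;

    let max_new_tokens = 64u32;
    // The C++ side throws OUTPUT_BUFFER_NOT_BIG_ENOUGH when this slice is
    // shorter than max_new_tokens, so size it accordingly.
    let mut generated = vec![0i32; max_new_tokens as usize];

    // generate() takes self: Pin<&mut LlamaCppBackendImpl>, hence pin_mut().
    let n_generated = backend
        .pin_mut()
        .generate(prompt, &mut generated, 40, max_new_tokens)?;

    generated.truncate(n_generated);
    Ok(generated)
}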
