Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

integrate ability to use APIKeys #411

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deploy/ansible/worker/tasks/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@
precheck_endpoint_url != None and
precheck_endpoint_url | trim != ''

- name: Append api key arg for given precheck_endpoint
ansible.builtin.set_fact:
worker_args: "{{ worker_args }} --precheck-api-key {{ precheck_api_key }}"
when: precheck_api_key is defined and
precheck_api_key != None and
precheck_api_key | trim != ''

- name: Set tls-insecure if enabled
ansible.builtin.set_fact:
worker_args: "{{ worker_args }} --tls-insecure \"true\""
Expand Down
11 changes: 10 additions & 1 deletion worker/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ var (
TlsClientCertPath string
TlsClientKeyPath string
TlsServerCaCertPath string
PrecheckAPIKey string
TlsInsecure bool
MaxSeed int
TaxonomyFolders = []string{"compositional_skills", "knowledge"}
Expand Down Expand Up @@ -79,6 +80,7 @@ type Worker struct {
logger *zap.SugaredLogger
job string
precheckEndpoint string
precheckAPIKey string
sdgEndpoint string
jobStart time.Time
tlsClientCertPath string
Expand All @@ -88,14 +90,15 @@ type Worker struct {
cmdRun string
}

func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, precheckAPIKey, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
return &Worker{
ctx: ctx,
pool: pool,
svc: svc,
logger: logger,
job: job,
precheckEndpoint: precheckEndpoint,
precheckAPIKey: precheckAPIKey,
sdgEndpoint: sdgEndpoint,
jobStart: time.Now(),
tlsClientCertPath: tlsClientCertPath,
Expand All @@ -115,6 +118,7 @@ func init() {
generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in")
generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory")
generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().StringVarP(&PrecheckAPIKey, "precheck-api-key", "", "", "The APIKey for the precheck-endpoint-url.")
generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.")
generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate")
generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo")
Expand Down Expand Up @@ -201,6 +205,7 @@ var generateCmd = &cobra.Command{
}
NewJobProcessor(ctx, pool, svc, sugar, job,
PreCheckEndpointURL,
PrecheckAPIKey,
SdgEndpointURL,
TlsClientCertPath,
TlsClientKeyPath,
Expand Down Expand Up @@ -432,6 +437,10 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
if PreCheckEndpointURL != localEndpoint && modelName != "unknown" {
commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName)
}
if PrecheckAPIKey != "" {
commandStr += fmt.Sprintf(" --precheck-api-key %s", PrecheckAPIKey)
}

cmdArgs := strings.Fields(commandStr)
cmd := exec.Command(lab, cmdArgs...)
// Register the command for reporting/logging
Expand Down
2 changes: 2 additions & 0 deletions worker/cmd/generate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"precheck-api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down Expand Up @@ -214,6 +215,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) {
zap.NewExample().Sugar(),
"job-id",
mockServer.URL,
"precheck-api-key",
"http://sdg-example.com",
"dummy-client-cert-path.pem",
"dummy-client-key-path.pem",
Expand Down
Loading