diff --git a/deploy/ansible/worker/tasks/main.yml b/deploy/ansible/worker/tasks/main.yml index 0cf493f4..df3fa1e3 100644 --- a/deploy/ansible/worker/tasks/main.yml +++ b/deploy/ansible/worker/tasks/main.yml @@ -52,6 +52,13 @@ precheck_endpoint_url != None and precheck_endpoint_url | trim != '' +- name: Append api key arg for given precheck_endpoint + ansible.builtin.set_fact: + worker_args: "{{ worker_args }} --precheck-api-key {{ precheck_api_key }}" + when: precheck_api_key is defined and + precheck_api_key != None and + precheck_api_key | trim != '' + - name: Set tls-insecure if enabled ansible.builtin.set_fact: worker_args: "{{ worker_args }} --tls-insecure \"true\"" diff --git a/worker/cmd/generate.go b/worker/cmd/generate.go index 4fcc6500..6adf78cb 100644 --- a/worker/cmd/generate.go +++ b/worker/cmd/generate.go @@ -46,6 +46,7 @@ var ( TlsClientCertPath string TlsClientKeyPath string TlsServerCaCertPath string + PrecheckAPIKey string TlsInsecure bool MaxSeed int TaxonomyFolders = []string{"compositional_skills", "knowledge"} @@ -79,6 +80,7 @@ type Worker struct { logger *zap.SugaredLogger job string precheckEndpoint string + precheckAPIKey string sdgEndpoint string jobStart time.Time tlsClientCertPath string @@ -88,7 +90,7 @@ type Worker struct { cmdRun string } -func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { +func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, precheckAPIKey, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker { return &Worker{ ctx: ctx, pool: pool, @@ -96,6 +98,7 @@ func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logg logger: logger, job: job, precheckEndpoint: precheckEndpoint, + precheckAPIKey: precheckAPIKey, sdgEndpoint: sdgEndpoint, jobStart: time.Now(), tlsClientCertPath: tlsClientCertPath, @@ -115,6 +118,7 @@ func init() { generateCmd.Flags().StringVarP(&WorkDir, "work-dir", "w", "", "Directory to work in") generateCmd.Flags().StringVarP(&VenvDir, "venv-dir", "v", "", "The virtual environment directory") generateCmd.Flags().StringVarP(&PreCheckEndpointURL, "precheck-endpoint-url", "e", "", "Endpoint hosting the model API. Default, it assumes the model is served locally.") + generateCmd.Flags().StringVarP(&PrecheckAPIKey, "precheck-api-key", "", "", "The APIKey for the precheck-endpoint-url.") generateCmd.Flags().StringVarP(&SdgEndpointURL, "sdg-endpoint-url", "", "http://localhost:8000/v1", "Endpoint hosting the model API. Default, it assumes the model is served locally.") generateCmd.Flags().IntVarP(&NumInstructions, "num-instructions", "n", 10, "The number of instructions to generate") generateCmd.Flags().StringVarP(&GitRemote, "git-remote", "", "https://github.com/instructlab/taxonomy", "The git remote for the taxonomy repo") @@ -201,6 +205,7 @@ var generateCmd = &cobra.Command{ } NewJobProcessor(ctx, pool, svc, sugar, job, PreCheckEndpointURL, + PrecheckAPIKey, SdgEndpointURL, TlsClientCertPath, TlsClientKeyPath, @@ -432,6 +437,10 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error { if PreCheckEndpointURL != localEndpoint && modelName != "unknown" { commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName) } + if PrecheckAPIKey != "" { + commandStr += fmt.Sprintf(" --precheck-api-key %s", PrecheckAPIKey) + } + cmdArgs := strings.Fields(commandStr) cmd := exec.Command(lab, cmdArgs...) // Register the command for reporting/logging diff --git a/worker/cmd/generate_test.go b/worker/cmd/generate_test.go index 6102c18e..041603c6 100644 --- a/worker/cmd/generate_test.go +++ b/worker/cmd/generate_test.go @@ -153,6 +153,7 @@ func TestFetchModelName(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + "precheck-api-key", "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem", @@ -214,6 +215,7 @@ func TestFetchModelNameWithInvalidObject(t *testing.T) { zap.NewExample().Sugar(), "job-id", mockServer.URL, + "precheck-api-key", "http://sdg-example.com", "dummy-client-cert-path.pem", "dummy-client-key-path.pem",