From 1b0b3ba586704cba65b7b65eec41ed0f256719e0 Mon Sep 17 00:00:00 2001 From: Gauri Yadav Date: Mon, 29 Jun 2026 12:01:07 +0530 Subject: [PATCH] XRAY-145929 - Onboard hugging face for jf curation-audit --- cli/docs/flags.go | 6 +- cli/docs/scan/curation/help.go | 3 +- cli/scancommands.go | 1 + commands/curation/curationaudit.go | 158 +++++++- commands/curation/curationaudit_test.go | 93 +++++ sca/bom/buildinfo/buildinfobom.go | 8 + sca/bom/buildinfo/technologies/common.go | 8 + .../huggingface/discovery/notebook.go | 91 +++++ .../huggingface/discovery/notebook_test.go | 88 ++++ .../huggingface/discovery/python.go | 383 ++++++++++++++++++ .../huggingface/discovery/python_test.go | 159 ++++++++ .../huggingface/discovery/result.go | 57 +++ .../huggingface/discovery/scanner.go | 155 +++++++ .../huggingface/discovery/scanner_test.go | 93 +++++ .../technologies/huggingface/huggingface.go | 229 +++++++++++ .../huggingface/huggingface_test.go | 243 +++++++++++ utils/techutils/techutils.go | 22 +- 17 files changed, 1785 insertions(+), 12 deletions(-) create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/notebook.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/notebook_test.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/python.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/python_test.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/result.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/scanner.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/discovery/scanner_test.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/huggingface.go create mode 100644 sca/bom/buildinfo/technologies/huggingface/huggingface_test.go diff --git a/cli/docs/flags.go b/cli/docs/flags.go index 9fdf48e1c..788d91df5 100644 --- a/cli/docs/flags.go +++ b/cli/docs/flags.go @@ -160,6 +160,7 @@ const ( // Unique curation flags 
CurationOutput = "curation-format" DockerImageName = "image" + HuggingFaceModel = "hugging-face-model" SolutionPath = "solution-path" IncludeCachedPackages = "include-cached-packages" LegacyPeerDeps = "legacy-peer-deps" @@ -228,7 +229,7 @@ var commandFlags = map[string][]string{ StaticSca, XrayLibPluginBinaryCustomPath, AnalyzerManagerCustomPath, AddSastRules, }, CurationAudit: { - CurationOutput, WorkingDirs, Threads, RequirementsFile, InsecureTls, useWrapperAudit, UseIncludedBuilds, SolutionPath, DockerImageName, IncludeCachedPackages, MvnIncludePluginDeps, LegacyPeerDeps, RunNative, + CurationOutput, WorkingDirs, Threads, RequirementsFile, InsecureTls, useWrapperAudit, UseIncludedBuilds, SolutionPath, DockerImageName, HuggingFaceModel, IncludeCachedPackages, MvnIncludePluginDeps, LegacyPeerDeps, RunNative, }, GitCountContributors: { InputFile, ScmType, ScmApiUrl, Token, Owner, RepoName, Months, DetailedSummary, InsecureTls, GitThreads, CacheValidity, @@ -370,7 +371,8 @@ var flagsMap = map[string]components.Flag{ UseConfigProfile: components.NewBoolFlag(UseConfigProfile, "Set to false to override config profile for the audit.", components.WithBoolDefaultValue(true), components.SetHiddenBoolFlag()), // Docker flags - DockerImageName: components.NewStringFlag(DockerImageName, "Specifies the Docker image name to audit. Uses the same format as the Docker CLI, including Artifactory-hosted images."), + DockerImageName: components.NewStringFlag(DockerImageName, "Specifies the Docker image name to audit. Uses the same format as the Docker CLI, including Artifactory-hosted images."), + HuggingFaceModel: components.NewStringFlag(HuggingFaceModel, "Specifies one or more Hugging Face models or datasets to audit, comma-separated, in the format ':' (e.g. 'mcpotato/42-eicar-street:main,bert-base-uncased'). The revision is optional and defaults to 'main' when omitted. These models are audited in addition to any auto-discovered from source. 
The Artifactory repository is read from the HF_ENDPOINT environment variable."), // Git flags InputFile: components.NewStringFlag(InputFile, "Path to an input file in YAML format contains multiple git providers. With this option, all other scm flags will be ignored and only git servers mentioned in the file will be examined.."), diff --git a/cli/docs/scan/curation/help.go b/cli/docs/scan/curation/help.go index 182736e7e..43daca1be 100644 --- a/cli/docs/scan/curation/help.go +++ b/cli/docs/scan/curation/help.go @@ -14,7 +14,7 @@ When to use: Prerequisites: - A configured JFrog Platform server (jf c add) with JFrog Curation entitlement. -- Project must use a supported package manager (npm, yarn, pip, maven, gradle, nuget, go) resolved through a curation-configured remote. +- Project must use a supported package manager (npm, yarn, pip, maven, gradle, nuget, go) resolved through a curation-configured remote. Docker images and Hugging Face models/datasets are audited via dedicated flags. - The package manager and its lockfile must be present in the working directory. Common patterns: @@ -23,6 +23,7 @@ Common patterns: $ jf curation-audit --format=json --threads=4 $ jf curation-audit --requirements-file=requirements-dev.txt $ jf curation-audit --docker-image=my-image:tag + $ HF_ENDPOINT=https://my.jfrog.io/artifactory/api/huggingfaceml/my-hf-repo jf curation-audit --hugging-face-model=org/model:main Gotchas: - The user/token must be entitled for Curation; otherwise the command exits with an entitlement notice. diff --git a/cli/scancommands.go b/cli/scancommands.go index 05ba99a71..e29daf783 100644 --- a/cli/scancommands.go +++ b/cli/scancommands.go @@ -754,6 +754,7 @@ func getCurationCommand(c *components.Context) (*curation.CurationAuditCommand, SetPipRequirementsFile(c.GetStringFlagValue(flags.RequirementsFile)). 
SetSolutionFilePath(c.GetStringFlagValue(flags.SolutionPath)) curationAuditCommand.SetDockerImageName(c.GetStringFlagValue(flags.DockerImageName)) + curationAuditCommand.SetHuggingFaceModel(c.GetStringFlagValue(flags.HuggingFaceModel)) curationAuditCommand.SetIncludeCachedPackages(c.GetBoolFlagValue(flags.IncludeCachedPackages)) curationAuditCommand.SetMvnIncludePluginDeps(c.GetBoolFlagValue(flags.MvnIncludePluginDeps)) curationAuditCommand.SetLegacyPeerDeps(c.GetBoolFlagValue(flags.LegacyPeerDeps)) diff --git a/commands/curation/curationaudit.go b/commands/curation/curationaudit.go index 1f51ae446..f206184f4 100644 --- a/commands/curation/curationaudit.go +++ b/commands/curation/curationaudit.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "fmt" + "io/fs" "net/http" "os" "path/filepath" @@ -39,6 +40,7 @@ import ( "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/docker" + "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/huggingface" npmtech "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/npm" pnpmtech "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/pnpm" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/python" @@ -93,6 +95,11 @@ const ( "are blocked by the curation policy. Details of the policy violations are shown in the table below.\n" + "Dependency analysis cannot proceed until these issues are addressed.\n" + "Once you apply a waiver or switch to an approved version and re-run the audit, additional results will be available." + + // hfUnresolvedReportKey is the results-map key for a Hugging Face scan that found + // only unresolved (dynamic) references and no statically-resolvable models, so it + // has warnings to surface but no curation table of its own. 
+ hfUnresolvedReportKey = "huggingface (unresolved references)" ) var CurationOutputFormats = []string{string(outFormat.Table), string(outFormat.Json)} @@ -120,7 +127,8 @@ var supportedTech = map[techutils.Technology]func(ca *CurationAuditCommand) (boo techutils.Gem: func(ca *CurationAuditCommand) (bool, error) { return ca.checkSupportByVersionOrEnv(techutils.Gem, MinArtiGradleGemSupport) }, - techutils.Docker: func(ca *CurationAuditCommand) (bool, error) { return true, nil }, + techutils.Docker: func(ca *CurationAuditCommand) (bool, error) { return true, nil }, + techutils.HuggingFaceMl: func(ca *CurationAuditCommand) (bool, error) { return true, nil }, techutils.Poetry: func(ca *CurationAuditCommand) (bool, error) { return ca.checkSupportByVersionOrEnv(techutils.Poetry, MinArtiPassThroughSupport) }, @@ -248,6 +256,7 @@ type CurationAuditCommand struct { OriginPath string parallelRequests int dockerImageName string + huggingFaceModel string includeCachedPackages bool mvnIncludePluginDeps bool audit.AuditParamsInterface @@ -261,6 +270,11 @@ type CurationReport struct { // was produced via the metadata-API fallback. The partial-report warning // is printed after the spinner stops so it is not swallowed by the spinner. isPartial bool + // warnings carries non-fatal, user-facing messages produced while building + // the dependency tree (e.g. Hugging Face references that could not be + // statically resolved). They are printed after the curation table so the + // coverage gap stays visible instead of being buried in BOM-build output. 
+ warnings []string } type WaiverResponse struct { @@ -300,6 +314,15 @@ func (ca *CurationAuditCommand) SetDockerImageName(dockerImageName string) *Cura return ca } +func (ca *CurationAuditCommand) HuggingFaceModel() string { + return ca.huggingFaceModel +} + +func (ca *CurationAuditCommand) SetHuggingFaceModel(huggingFaceModel string) *CurationAuditCommand { + ca.huggingFaceModel = huggingFaceModel + return ca +} + func (ca *CurationAuditCommand) SetIncludeCachedPackages(includeCachedPackages bool) *CurationAuditCommand { ca.includeCachedPackages = includeCachedPackages return ca @@ -362,6 +385,13 @@ func (ca *CurationAuditCommand) Run() (err error) { for projectPath, packagesStatus := range results { err = errors.Join(err, printResult(ca.OutputFormat(), projectPath, packagesStatus.packagesStatus)) + // Surface tree-build warnings (e.g. Hugging Face references that could not be + // statically resolved) after the table, so the coverage gap is the last thing + // the user sees rather than being buried in the BOM-build output above. + for _, w := range packagesStatus.warnings { + log.Warn(w) + } + for _, ps := range packagesStatus.packagesStatus { if ps.WaiverAllowed && !utils.IsCI() { // If at least one package allows waiver requests, we will ask the user if they want to request a waiver @@ -469,6 +499,19 @@ func (ca *CurationAuditCommand) doCurateAudit(results map[string]*CurationReport log.Debug(fmt.Sprintf("Docker image name '%s' was provided, running Docker curation audit.", ca.DockerImageName())) techs = []string{techutils.Docker.String()} } + // --hugging-face-model: explicit spot-check — run HF only (skip pip/npm/etc. so + // the user gets a fast single-model verdict without waiting for full dep resolution). + // Auto-discovery (HF_ENDPOINT set + .py files present): additive — HF runs alongside + // the detected package managers so a full audit still covers both surfaces. 
+ if ca.HuggingFaceModel() != "" { + log.Debug(fmt.Sprintf("Hugging Face models '%s' were provided explicitly — running HF-only audit.", ca.HuggingFaceModel())) + techs = []string{techutils.HuggingFaceMl.String()} + } else if os.Getenv("HF_ENDPOINT") != "" && hasPythonFiles(ca.OriginPath) { + // Auto-discovery: attempt an HF source scan when HF_ENDPOINT is configured and + // .py/.ipynb files are present. BuildDependencyTree returns gracefully if no + // HF call sites are found. + techs = appendIfMissing(techs, techutils.HuggingFaceMl.String()) + } // Resolve npm→yarn when the project was configured with 'jf yarn-config' (yarn.yaml exists) // but has no yarn.lock/.yarnrc.yml so the file-based detector picked npm instead. for i, tech := range techs { @@ -500,6 +543,44 @@ func (ca *CurationAuditCommand) doCurateAudit(results map[string]*CurationReport return nil } +// appendIfMissing appends value to slice only if it is not already present, +// keeping the technology list free of duplicates when a tech is both detected +// and requested explicitly. +func appendIfMissing(slice []string, value string) []string { + for _, v := range slice { + if v == value { + return slice + } + } + return append(slice, value) +} + +// hasPythonFiles returns true if dir contains at least one .py or .ipynb file, +// indicating the project may have Hugging Face model references to discover. +func hasPythonFiles(dir string) bool { + if dir == "" { + dir = "." 
+ } + found := false + _ = filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error { + if err != nil || found { + return nil + } + if d.IsDir() { + if n := d.Name(); n == ".git" || n == ".venv" || n == "venv" || n == "__pycache__" || n == "node_modules" { + return filepath.SkipDir + } + return nil + } + ext := strings.ToLower(filepath.Ext(path)) + if ext == ".py" || ext == ".ipynb" { + found = true + } + return nil + }) + return found +} + // resolveNpmYarnTech upgrades npm→yarn when the project has yarn.yaml but no npm.yaml — // the developer ran 'jf yarn-config' but the file-system detector fell back to npm. func resolveNpmYarnTech(tech string) string { @@ -596,6 +677,9 @@ func (ca *CurationAuditCommand) getBuildInfoParamsByTech() (technologies.BuildIn PipRequirementsFile: ca.PipRequirementsFile(), // Docker params DockerImageName: ca.DockerImageName(), + // Hugging Face params + HuggingFaceModel: ca.HuggingFaceModel(), + WorkingDirectory: ca.OriginPath, // NuGet params SolutionFilePath: ca.SolutionFilePath(), }, err @@ -639,6 +723,17 @@ func (ca *CurationAuditCommand) auditTree(tech techutils.Technology, results map } // Validate the graph isn't empty. if len(depTreeResult.FullDepTrees) == 0 { + // For HuggingFace auto-discovery, no models found is a normal outcome + // (the project has .py files but no HF call sites) — not an error. + if tech == techutils.HuggingFaceMl { + log.Debug("Hugging Face: no model references discovered in source — skipping HF curation probe") + // There may still be unresolved (dynamic) references worth surfacing + // even when nothing statically resolved to a curation probe. 
+ if len(depTreeResult.Warnings) > 0 { + results[hfUnresolvedReportKey] = &CurationReport{warnings: depTreeResult.Warnings} + } + return nil + } return errorutils.CheckErrorf("found no dependencies for the audited project using '%v' as the package manager", tech.String()) } rtManager, serverDetails, err := ca.getRtManagerAndAuth(tech) @@ -707,6 +802,7 @@ func (ca *CurationAuditCommand) auditTree(tech techutils.Technology, results map packagesStatus: packagesStatus, // We subtract 1 because the root node is not a package. totalNumberOfPackages: len(depTreeResult.FlatTree.Nodes) - 1, + warnings: depTreeResult.Warnings, } return err } @@ -925,6 +1021,17 @@ func (ca *CurationAuditCommand) SetRepo(tech techutils.Technology) error { return nil } + // Hugging Face resolves its Artifactory repo from the HF_ENDPOINT environment variable, + // not from a 'jf <tech>-config' yaml file. + if tech == techutils.HuggingFaceMl { + repoConfig, err := huggingface.GetHuggingFaceRepositoryConfig() + if err != nil { + return err + } + ca.setPackageManagerConfig(repoConfig) + return nil + } + // When --run-native is set for npm, read the Artifactory URL and repo name from the // project's .npmrc via native npm config — no jf npm-config/npm.yaml required. if ca.RunNative() && tech == techutils.Npm { @@ -1184,6 +1291,14 @@ func (nc *treeAnalyzer) fetchNodeStatus(node xrayUtils.GraphNode, p *sync.Map) e } return nil } + // Hugging Face: a 404 means the model/dataset+revision is not resolvable through the + // proxy (e.g. uncached with on-demand repositories disabled, an unknown revision, or a + // dataset the catalog does not track). That is not a curation block and must not fail + // the whole audit — treat it as "not blocked" and move on, like the NuGet 404 skip below. 
+ if resp != nil && resp.StatusCode == http.StatusNotFound && nc.tech == techutils.HuggingFaceMl { + log.Debug(fmt.Sprintf("Hugging Face: %s:%s not resolvable at %s (HTTP 404) — skipping", name, version, packageUrl)) + continue + } if err != nil { if resp != nil && resp.StatusCode >= 400 { return errorutils.CheckErrorf(errorTemplateHeadRequest, packageUrl, name, version, resp.StatusCode, err) @@ -1617,6 +1732,9 @@ func getUrlNameAndVersionByTech(tech techutils.Technology, node *xrayUtils.Graph case techutils.Docker: downloadUrls, name, version = getDockerNameAndVersion(node.Id, artiUrl, repo) return + case techutils.HuggingFaceMl: + downloadUrls, name, version = getHuggingFaceNameAndVersion(node.Id, artiUrl, repo) + return } return } @@ -1862,6 +1980,44 @@ func getDockerNameAndVersion(id, artiUrl, repo string) (downloadUrls []string, n return } +// getHuggingFaceNameAndVersion extracts the model id and revision from a node id of the +// form "huggingfaceml://<repo_id>:<revision>" and builds the model-info probe URL. +// +// The probe targets the model metadata endpoint, which the curation service blocks +// (HEAD returns 403) for a malicious revision — independent of any specific file: +// +// {artiUrl}/api/huggingfaceml/{repo}/api/models/{repo_id}/revision/{revision} +func getHuggingFaceNameAndVersion(id, artiUrl, repo string) (downloadUrls []string, name, version string) { + if id == "" { + return + } + id = strings.TrimPrefix(id, huggingface.HuggingFacePackagePrefix) + + // Datasets are probed via api/datasets/ instead of api/models/. The repo type is + // carried by an optional "dataset|" marker placed right after the scheme prefix. + repoTypePath := "api/models" + if strings.HasPrefix(id, huggingface.DatasetNodeMarker) { + repoTypePath = "api/datasets" + id = strings.TrimPrefix(id, huggingface.DatasetNodeMarker) + } + + // The repo id (e.g. "mcpotato/42-eicar-street") contains '/' but never ':'; the + // revision suffix never contains '/', so split on the last ':'. 
+ if idx := strings.LastIndex(id, ":"); idx > 0 && !strings.Contains(id[idx+1:], "/") { + name = id[:idx] + version = id[idx+1:] + } else { + name = id + version = huggingface.DefaultRevision + } + + if artiUrl != "" && repo != "" { + downloadUrls = []string{fmt.Sprintf("%s/api/huggingfaceml/%s/%s/%s/revision/%s", + strings.TrimSuffix(artiUrl, "/"), repo, repoTypePath, name, version)} + } + return +} + func GetCurationOutputFormat(formatFlagVal string) (format outFormat.OutputFormat, err error) { // Default print format is table. format = outFormat.Table diff --git a/commands/curation/curationaudit_test.go b/commands/curation/curationaudit_test.go index 0aefd5b66..4b80add14 100644 --- a/commands/curation/curationaudit_test.go +++ b/commands/curation/curationaudit_test.go @@ -1386,6 +1386,99 @@ func Test_getDockerNameAndVersion(t *testing.T) { }) } } +func Test_getHuggingFaceNameAndVersion(t *testing.T) { + tests := []struct { + name string + id string + artiUrl string + repo string + wantDownloadUrls []string + wantName string + wantVersion string + }{ + { + name: "model with explicit sha revision", + id: "huggingfaceml://mcpotato/42-eicar-street:8fb61c4d511e9aaff0ea55396a124aa292830efc", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/models/mcpotato/42-eicar-street/revision/8fb61c4d511e9aaff0ea55396a124aa292830efc"}, + wantName: "mcpotato/42-eicar-street", + wantVersion: "8fb61c4d511e9aaff0ea55396a124aa292830efc", + }, + { + name: "model with branch revision", + id: "huggingfaceml://bert-base-uncased:main", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/models/bert-base-uncased/revision/main"}, + wantName: "bert-base-uncased", + wantVersion: "main", + }, + { + name: "model 
id with no revision defaults to main", + id: "huggingfaceml://org/model", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/models/org/model/revision/main"}, + wantName: "org/model", + wantVersion: "main", + }, + { + name: "dataset is probed via api/datasets endpoint", + id: "huggingfaceml://dataset|squad:main", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/datasets/squad/revision/main"}, + wantName: "squad", + wantVersion: "main", + }, + { + name: "dataset with org and explicit revision", + id: "huggingfaceml://dataset|stanfordnlp/squad:v2.0", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/datasets/stanfordnlp/squad/revision/v2.0"}, + wantName: "stanfordnlp/squad", + wantVersion: "v2.0", + }, + { + name: "empty artiUrl and repo produce no download URL", + id: "huggingfaceml://org/model:main", + artiUrl: "", + repo: "", + wantDownloadUrls: nil, + wantName: "org/model", + wantVersion: "main", + }, + { + name: "empty id returns empty results", + id: "", + artiUrl: "https://test.jfrogdev.org/artifactory", + repo: "my-hugging-face-repo", + wantDownloadUrls: nil, + wantName: "", + wantVersion: "", + }, + { + name: "trailing slash stripped from artiUrl", + id: "huggingfaceml://org/model:v1.0", + artiUrl: "https://test.jfrogdev.org/artifactory/", + repo: "my-hugging-face-repo", + wantDownloadUrls: []string{"https://test.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo/api/models/org/model/revision/v1.0"}, + wantName: "org/model", + wantVersion: "v1.0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + gotDownloadUrls, gotName, gotVersion := getHuggingFaceNameAndVersion(tt.id, tt.artiUrl, tt.repo) + assert.Equal(t, tt.wantDownloadUrls, gotDownloadUrls, "downloadUrls mismatch") + assert.Equal(t, tt.wantName, gotName, "name mismatch") + assert.Equal(t, tt.wantVersion, gotVersion, "version mismatch") + }) + } +} + func Test_getNugetNameScopeAndVersion(t *testing.T) { tests := []struct { name string diff --git a/sca/bom/buildinfo/buildinfobom.go b/sca/bom/buildinfo/buildinfobom.go index 66fed8596..c55049ef5 100644 --- a/sca/bom/buildinfo/buildinfobom.go +++ b/sca/bom/buildinfo/buildinfobom.go @@ -32,6 +32,7 @@ import ( "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/docker" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/gem" _go "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/go" + "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/huggingface" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/java" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/npm" "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/nuget" @@ -233,6 +234,11 @@ type DependencyTreeResult struct { FlatTree *xrayUtils.GraphNode FullDepTrees []*xrayUtils.GraphNode DownloadUrls map[string]string + // Warnings carries non-fatal, user-facing messages produced while building the + // tree (e.g. Hugging Face references that could not be statically resolved). + // They are surfaced to the user after the curation tables rather than during + // the noisy BOM-build phase, so coverage gaps stay visible instead of silent. 
+ Warnings []string } func GetTechDependencyTree(params technologies.BuildInfoBomGeneratorParams, artifactoryServerDetails *config.ServerDetails, tech techutils.Technology) (depTreeResult DependencyTreeResult, err error) { @@ -298,6 +304,8 @@ func GetTechDependencyTree(params technologies.BuildInfoBomGeneratorParams, arti depTreeResult.FullDepTrees, uniqueDepsIds, err = swift.BuildDependencyTree(params) case techutils.Docker: depTreeResult.FullDepTrees, uniqueDepsIds, err = docker.BuildDependencyTree(params) + case techutils.HuggingFaceMl: + depTreeResult.FullDepTrees, uniqueDepsIds, depTreeResult.Warnings, err = huggingface.BuildDependencyTree(params) default: err = errorutils.CheckErrorf("%s is currently not supported", string(tech)) } diff --git a/sca/bom/buildinfo/technologies/common.go b/sca/bom/buildinfo/technologies/common.go index cdf6a0a4f..4c936c6e4 100644 --- a/sca/bom/buildinfo/technologies/common.go +++ b/sca/bom/buildinfo/technologies/common.go @@ -75,6 +75,14 @@ type BuildInfoBomGeneratorParams struct { MaxTreeDepth string // Docker params DockerImageName string + // Hugging Face params + // HuggingFaceModel is the model/dataset reference to audit, e.g. "mcpotato/42-eicar-street:main". + // Set by the curation command's --hugging-face-model flag. + // When empty, BuildDependencyTree auto-discovers references from WorkingDirectory. + HuggingFaceModel string + // WorkingDirectory is the project root used for Hugging Face auto-discovery. + // Defaults to "." when empty. 
+ WorkingDirectory string // NuGet params SolutionFilePath string } diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/notebook.go b/sca/bom/buildinfo/technologies/huggingface/discovery/notebook.go new file mode 100644 index 000000000..308adcd4a --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/notebook.go @@ -0,0 +1,91 @@ +package discovery + +import ( + "encoding/json" + "fmt" + "os" + "strings" +) + +// ipynbNotebook is the minimal structure we need from a .ipynb JSON file. +type ipynbNotebook struct { + Cells []ipynbCell `json:"cells"` +} + +type ipynbCell struct { + CellType string `json:"cell_type"` + // Source is either a []string (list of lines) or a single string, depending + // on the nbformat version; json.RawMessage lets us handle both. + Source json.RawMessage `json:"source"` +} + +// ParseNotebook reads a .ipynb file and returns discovered/unresolved entries. +// Code cells are extracted and fed through the Python scanner. Cells that +// contain notebook magic commands (leading ! or %) are cleaned before parsing. +func ParseNotebook(path string) (discovered []DiscoveredModel, unresolved []UnresolvedSite, err error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, fmt.Errorf("reading notebook %s: %w", path, err) + } + return parseNotebookBytes(data, path) +} + +// parseNotebookBytes is the testable core of ParseNotebook. +func parseNotebookBytes(data []byte, filename string) (discovered []DiscoveredModel, unresolved []UnresolvedSite, err error) { + var nb ipynbNotebook + if err = json.Unmarshal(data, &nb); err != nil { + return nil, nil, fmt.Errorf("parsing notebook JSON %s: %w", filename, err) + } + + for cellIdx, cell := range nb.Cells { + if cell.CellType != "code" { + continue + } + src, parseErr := cellSource(cell.Source) + if parseErr != nil { + // Malformed cell — skip rather than abort. 
+ continue + } + if strings.TrimSpace(src) == "" { + continue + } + src = stripMagics(src) + cellFile := fmt.Sprintf("%s#cell-%d", filename, cellIdx) + d, u := ParsePythonSource(src, cellFile, nil) + discovered = append(discovered, d...) + unresolved = append(unresolved, u...) + } + return +} + +// cellSource decodes the source field which can be a JSON string or []string. +func cellSource(raw json.RawMessage) (string, error) { + // Try []string first (most common in nbformat 4). + var lines []string + if err := json.Unmarshal(raw, &lines); err == nil { + return strings.Join(lines, ""), nil + } + // Fall back to a single string. + var s string + if err := json.Unmarshal(raw, &s); err != nil { + return "", err + } + return s, nil +} + +// stripMagics removes lines that start with ! or % (IPython magic commands +// and shell escapes) so the Python parser doesn't choke on them. +// Lines starting with # (comments) are left intact. +func stripMagics(src string) string { + lines := strings.Split(src, "\n") + out := make([]string, 0, len(lines)) + for _, l := range lines { + trimmed := strings.TrimLeft(l, " \t") + if strings.HasPrefix(trimmed, "!") || strings.HasPrefix(trimmed, "%") { + out = append(out, "") // preserve line numbers + } else { + out = append(out, l) + } + } + return strings.Join(out, "\n") +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/notebook_test.go b/sca/bom/buildinfo/technologies/huggingface/discovery/notebook_test.go new file mode 100644 index 000000000..aa5f6c01f --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/notebook_test.go @@ -0,0 +1,88 @@ +package discovery + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseNotebook_CodeCell(t *testing.T) { + nb := `{ + "cells": [ + { + "cell_type": "markdown", + "source": ["# Heading"] + }, + { + "cell_type": "code", + "source": ["from transformers import AutoModel\n", "model = 
AutoModel.from_pretrained(\"org/model\", revision=\"v1\")\n"] + } + ] +}` + disc, unres, err := parseNotebookBytes([]byte(nb), "test.ipynb") + require.NoError(t, err) + require.Len(t, disc, 1) + assert.Equal(t, "org/model", disc[0].RepoID) + assert.Equal(t, "v1", disc[0].Revision) + assert.Contains(t, disc[0].Sources[0].File, "test.ipynb#cell-") + assert.Empty(t, unres) +} + +func TestParseNotebook_MagicsStripped(t *testing.T) { + nb := `{ + "cells": [ + { + "cell_type": "code", + "source": ["!pip install transformers\n", "%load_ext autoreload\n", "from_pretrained(\"org/model\")\n"] + } + ] +}` + disc, _, err := parseNotebookBytes([]byte(nb), "nb.ipynb") + require.NoError(t, err) + require.Len(t, disc, 1) + assert.Equal(t, "org/model", disc[0].RepoID) +} + +func TestParseNotebook_MarkdownCellSkipped(t *testing.T) { + nb := `{ + "cells": [ + { + "cell_type": "markdown", + "source": ["snapshot_download(repo_id=\"org/should-not-match\")"] + } + ] +}` + disc, _, err := parseNotebookBytes([]byte(nb), "nb.ipynb") + require.NoError(t, err) + assert.Empty(t, disc) +} + +func TestParseNotebook_InvalidJSON(t *testing.T) { + _, _, err := parseNotebookBytes([]byte("not json"), "bad.ipynb") + require.Error(t, err) +} + +func TestStripMagics(t *testing.T) { + src := "!pip install foo\n%load_ext bar\nimport torch\n" + out := stripMagics(src) + lines := splitLines(out) + assert.Equal(t, "", lines[0]) + assert.Equal(t, "", lines[1]) + assert.Equal(t, "import torch", lines[2]) +} + +func splitLines(s string) []string { + var out []string + start := 0 + for i := 0; i < len(s); i++ { + if s[i] == '\n' { + out = append(out, s[start:i]) + start = i + 1 + } + } + if start < len(s) { + out = append(out, s[start:]) + } + return out +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/python.go b/sca/bom/buildinfo/technologies/huggingface/discovery/python.go new file mode 100644 index 000000000..5c93c98ba --- /dev/null +++ 
b/sca/bom/buildinfo/technologies/huggingface/discovery/python.go @@ -0,0 +1,383 @@ +package discovery + +import ( + "regexp" + "strings" +) + +// callPattern describes one of the HF call signatures we recognise. +type callPattern struct { + // namePattern is matched against the bare function/method name (suffix after the last dot). + namePattern *regexp.Regexp + // repoIDArgName is the keyword-arg name for repo_id (empty = first positional only). + repoIDArgName string + // revisionArgName is the keyword-arg name for the revision. + revisionArgName string + // repoTypeArgName is the keyword-arg name for repo_type (empty = fixed). + repoTypeArgName string + // defaultRepoType is used when repoTypeArgName is empty or not present in the call. + defaultRepoType RepoType +} + +// callName extracts the bare function/method name from a raw call token. +// e.g. "AutoModel.from_pretrained" → "from_pretrained", "snapshot_download" → "snapshot_download". +func callName(s string) string { + if idx := strings.LastIndex(s, "."); idx >= 0 { + return s[idx+1:] + } + return s +} + +var knownCalls = []callPattern{ + { + namePattern: regexp.MustCompile(`^from_pretrained$`), + repoIDArgName: "pretrained_model_name_or_path", + revisionArgName: "revision", + defaultRepoType: RepoTypeModel, + }, + { + namePattern: regexp.MustCompile(`^snapshot_download$`), + repoIDArgName: "repo_id", + revisionArgName: "revision", + repoTypeArgName: "repo_type", + defaultRepoType: RepoTypeModel, + }, + { + namePattern: regexp.MustCompile(`^hf_hub_download$`), + repoIDArgName: "repo_id", + revisionArgName: "revision", + repoTypeArgName: "repo_type", + defaultRepoType: RepoTypeModel, + }, + { + namePattern: regexp.MustCompile(`^load_dataset$`), + repoIDArgName: "", + revisionArgName: "revision", + defaultRepoType: RepoTypeDataset, + }, +} + +// argValue represents one resolved argument value. 
+type argValue struct { + literal string // non-empty when the value is a string literal + isDynamic bool // true when value is non-literal (Name not in const table, f-string, call, etc.) + isFString bool // sub-kind of dynamic: f-string (for a more specific reason string) +} + +// ParsePythonSource scans a Python source string and returns all discovered HF +// references. filename is used for Location.File in the results. +// constTable allows callers (e.g. a multi-file scanner) to pass in pre-populated +// module-level constants; pass nil and it will be built from src itself. +func ParsePythonSource(src, filename string, constTable map[string]string) (discovered []DiscoveredModel, unresolved []UnresolvedSite) { + if constTable == nil { + constTable = buildConstTable(src) + } + logicalLines := joinContinuationLines(src) + for _, ll := range logicalLines { + d, u := matchLogicalLine(ll.text, ll.startLine, filename, constTable) + discovered = append(discovered, d...) + unresolved = append(unresolved, u...) + } + return +} + +// ---- constant table ------------------------------------------------------- + +// simpleAssignRe matches: NAME = "literal" or NAME = 'literal' +// Anchored to avoid matching mid-expression assignments. +var simpleAssignRe = regexp.MustCompile(`(?m)^[ \t]*([A-Za-z_][A-Za-z0-9_]*)\s*=\s*["']([^"']+)["'][ \t]*(?:#.*)?$`) + +// buildConstTable performs a single pass over src and collects module-level +// single-assignment string constants. A name that is assigned more than once +// is removed from the table (we won't trust it). 
+func buildConstTable(src string) map[string]string { + table := map[string]string{} + seen := map[string]bool{} + for _, m := range simpleAssignRe.FindAllStringSubmatch(src, -1) { + name, val := m[1], m[2] + if seen[name] { + delete(table, name) // reassigned → not constant + } else { + seen[name] = true + table[name] = val + } + } + return table +} + +// ---- logical line joining -------------------------------------------------- + +type logicalLine struct { + text string + startLine int // 1-based line number of the first physical line +} + +// joinContinuationLines collapses backslash continuations and open-paren spans +// into single logical lines, preserving the starting line number of each. +func joinContinuationLines(src string) []logicalLine { + physical := strings.Split(src, "\n") + var result []logicalLine + var buf strings.Builder + startLine := 0 + depth := 0 + + flush := func(endIdx int) { + if buf.Len() > 0 { + result = append(result, logicalLine{text: buf.String(), startLine: startLine + 1}) + buf.Reset() + } + } + + for i, line := range physical { + trimmed := strings.TrimRight(line, " \t\r") + if buf.Len() == 0 { + startLine = i + } + // Count unquoted parens (rough but good enough for our call patterns) + depth += countParenDepthChange(trimmed) + + if strings.HasSuffix(trimmed, "\\") { + buf.WriteString(strings.TrimSuffix(trimmed, "\\")) + continue + } + buf.WriteString(trimmed) + if depth <= 0 { + depth = 0 + flush(i) + } else { + buf.WriteString(" ") + } + } + flush(len(physical)) + return result +} + +// countParenDepthChange returns the net change in paren depth ignoring quoted strings. +// An unquoted '#' starts a Python comment, so the rest of the line is prose (which may +// contain unbalanced parens or apostrophes like "tab's") and must be ignored — otherwise +// a comment can corrupt depth tracking and glue unrelated statements together. 
+func countParenDepthChange(s string) int { + depth := 0 + inSingle, inDouble := false, false + for i := 0; i < len(s); i++ { + c := s[i] + if c == '\\' && (inSingle || inDouble) { + i++ // skip escaped char + continue + } + switch { + case c == '#' && !inSingle && !inDouble: + return depth // rest of the line is a comment + case c == '\'' && !inDouble: + inSingle = !inSingle + case c == '"' && !inSingle: + inDouble = !inDouble + case c == '(' && !inSingle && !inDouble: + depth++ + case c == ')' && !inSingle && !inDouble: + depth-- + } + } + return depth +} + +// ---- call matching -------------------------------------------------------- + +// callRe matches a function/method call: captures the call expression and args +// between the outermost parens. The function name may include a dotted prefix. +var callRe = regexp.MustCompile(`((?:[A-Za-z_][A-Za-z0-9_.]*\.)?[A-Za-z_][A-Za-z0-9_]*)\s*\(([^)]*)\)`) + +func matchLogicalLine(line string, startLine int, filename string, constTable map[string]string) (discovered []DiscoveredModel, unresolved []UnresolvedSite) { + // Strip leading whitespace and inline comments for cleaner matching. + trimmed := strings.TrimLeft(line, " \t") + if strings.HasPrefix(trimmed, "#") { + return + } + + for _, m := range callRe.FindAllStringSubmatch(line, -1) { + fullName := m[1] + argsRaw := m[2] + name := callName(fullName) + + for _, cp := range knownCalls { + if !cp.namePattern.MatchString(name) { + continue + } + loc := Location{File: filename, Line: startLine} + d, u := resolveCall(cp, argsRaw, loc, constTable) + discovered = append(discovered, d...) + unresolved = append(unresolved, u...) + } + } + return +} + +// ---- argument resolution -------------------------------------------------- + +// resolveCall parses the raw argument string for a matched call and produces +// DiscoveredModel or UnresolvedSite entries. 
+func resolveCall(cp callPattern, argsRaw string, loc Location, constTable map[string]string) (discovered []DiscoveredModel, unresolved []UnresolvedSite) { + positional, keyword := parseArgs(argsRaw) + + // Resolve repo_id + var repoIDArg argValue + if cp.repoIDArgName != "" { + if v, ok := keyword[cp.repoIDArgName]; ok { + repoIDArg = resolveArgValue(v, constTable) + } else if len(positional) > 0 { + repoIDArg = resolveArgValue(positional[0], constTable) + } + } else if len(positional) > 0 { + repoIDArg = resolveArgValue(positional[0], constTable) + } + + if repoIDArg.literal == "" { + // Cannot determine repo_id — record as unresolved. + reason := "non-literal repo_id" + if repoIDArg.isFString { + reason = "f-string repo_id" + } + snippet := buildSnippet(loc.File, argsRaw) + unresolved = append(unresolved, UnresolvedSite{Location: loc, Snippet: snippet, Reason: reason}) + return + } + + // Resolve revision + var revision string + revDefaulted := false + revDynamic := false + if v, ok := keyword[cp.revisionArgName]; ok { + rv := resolveArgValue(v, constTable) + if rv.literal != "" { + revision = rv.literal + } else { + revDynamic = true + revision = DefaultRevision + } + } else { + revision = DefaultRevision + revDefaulted = true + } + + // Resolve repo_type + repoType := cp.defaultRepoType + if cp.repoTypeArgName != "" { + if v, ok := keyword[cp.repoTypeArgName]; ok { + rv := resolveArgValue(v, constTable) + if rv.literal == "dataset" { + repoType = RepoTypeDataset + } + } + } + + discovered = append(discovered, DiscoveredModel{ + RepoID: repoIDArg.literal, + Revision: revision, + RevisionDefaulted: revDefaulted, + RevisionDynamic: revDynamic, + RepoType: repoType, + Sources: []Location{loc}, + }) + return +} + +// parseArgs splits a raw argument string into positional and keyword slices. +// It handles simple cases well; deeply nested calls may produce imperfect +// splits, but those will fail the literal-string check and become UnresolvedSite. 
+func parseArgs(raw string) (positional []string, keyword map[string]string) { + keyword = map[string]string{} + if strings.TrimSpace(raw) == "" { + return + } + parts := splitArgs(raw) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + // keyword arg: name = value + if idx := strings.Index(part, "="); idx > 0 { + k := strings.TrimSpace(part[:idx]) + v := strings.TrimSpace(part[idx+1:]) + // Only treat as keyword if k is a valid identifier + if isIdentifier(k) { + keyword[k] = v + continue + } + } + positional = append(positional, part) + } + return +} + +// splitArgs splits a comma-separated argument list respecting quoted strings +// and nested parentheses. +func splitArgs(s string) []string { + var parts []string + depth := 0 + inSingle, inDouble := false, false + start := 0 + for i := 0; i < len(s); i++ { + c := s[i] + if c == '\\' && (inSingle || inDouble) { + i++ + continue + } + switch { + case c == '\'' && !inDouble: + inSingle = !inSingle + case c == '"' && !inSingle: + inDouble = !inDouble + case (c == '(' || c == '[' || c == '{') && !inSingle && !inDouble: + depth++ + case (c == ')' || c == ']' || c == '}') && !inSingle && !inDouble: + depth-- + case c == ',' && depth == 0 && !inSingle && !inDouble: + parts = append(parts, s[start:i]) + start = i + 1 + } + } + parts = append(parts, s[start:]) + return parts +} + +// resolveArgValue classifies a raw argument token. 
+func resolveArgValue(raw string, constTable map[string]string) argValue { + s := strings.TrimSpace(raw) + + // String literal: "value" or 'value' + if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`) && len(s) >= 2) || + (strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`) && len(s) >= 2) { + return argValue{literal: s[1 : len(s)-1]} + } + + // f-string → dynamic + if strings.HasPrefix(s, `f"`) || strings.HasPrefix(s, `f'`) || + strings.HasPrefix(s, `F"`) || strings.HasPrefix(s, `F'`) { + return argValue{isDynamic: true, isFString: true} + } + + // Simple identifier → check constant table + if isIdentifier(s) { + if v, ok := constTable[s]; ok { + return argValue{literal: v} + } + return argValue{isDynamic: true} + } + + return argValue{isDynamic: true} +} + +// isIdentifier returns true if s is a valid Python/Go identifier token. +var identRe = regexp.MustCompile(`^[A-Za-z_][A-Za-z0-9_]*$`) + +func isIdentifier(s string) bool { + return identRe.MatchString(s) +} + +func buildSnippet(filename, argsRaw string) string { + if len(argsRaw) > 60 { + return argsRaw[:60] + "..." 
+ } + return argsRaw +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/python_test.go b/sca/bom/buildinfo/technologies/huggingface/discovery/python_test.go new file mode 100644 index 000000000..9234ed559 --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/python_test.go @@ -0,0 +1,159 @@ +package discovery + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParsePythonSource_Literals(t *testing.T) { + src := ` +from transformers import AutoModel +model = AutoModel.from_pretrained("mcpotato/42-eicar-street", revision="main") +` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "mcpotato/42-eicar-street", disc[0].RepoID) + assert.Equal(t, "main", disc[0].Revision) + assert.False(t, disc[0].RevisionDefaulted) + assert.Equal(t, RepoTypeModel, disc[0].RepoType) + assert.Empty(t, unres) +} + +func TestParsePythonSource_RevisionDefaulted(t *testing.T) { + src := `snapshot_download(repo_id="bert-base-uncased")` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "bert-base-uncased", disc[0].RepoID) + assert.Equal(t, DefaultRevision, disc[0].Revision) + assert.True(t, disc[0].RevisionDefaulted) + assert.Empty(t, unres) +} + +func TestParsePythonSource_Dataset(t *testing.T) { + src := `from datasets import load_dataset +ds = load_dataset("squad", revision="1.0.0")` + disc, unres := ParsePythonSource(src, "train.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "squad", disc[0].RepoID) + assert.Equal(t, "1.0.0", disc[0].Revision) + assert.Equal(t, RepoTypeDataset, disc[0].RepoType) + assert.Empty(t, unres) +} + +func TestParsePythonSource_SnapshotDownloadWithRepoType(t *testing.T) { + src := `snapshot_download(repo_id="org/ds", revision="v2", repo_type="dataset")` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, 
RepoTypeDataset, disc[0].RepoType) + assert.Equal(t, "v2", disc[0].Revision) + assert.Empty(t, unres) +} + +// TestParsePythonSource_ApostropheInComment guards against a regression where an +// apostrophe inside a '#' comment (e.g. "tab's") was treated as an open string +// literal in countParenDepthChange, leaving paren depth stuck and gluing every +// following statement onto the comment line — which silently dropped real calls. +func TestParsePythonSource_ApostropheInComment(t *testing.T) { + src := `from huggingface_hub import snapshot_download + +# Whole-repo download (matches the "Resolve" tab's snapshot_download example). +LLAMA = snapshot_download(repo_id="meta-llama/Llama-2-7b-hf", revision="main") + +# A reference we EXPECT curation to block (it's malicious). +UNSAFE = snapshot_download(repo_id="mcpotato/42-eicar-street", revision="8fb61c4d511e9aaff0ea55396a124aa292830efc") +` + disc, unres := ParsePythonSource(src, "app.py", nil) + require.Len(t, disc, 2) + assert.Equal(t, "meta-llama/Llama-2-7b-hf", disc[0].RepoID) + assert.Equal(t, "main", disc[0].Revision) + assert.Equal(t, "mcpotato/42-eicar-street", disc[1].RepoID) + assert.Equal(t, "8fb61c4d511e9aaff0ea55396a124aa292830efc", disc[1].Revision) + assert.Empty(t, unres) +} + +func TestParsePythonSource_ConstantTable(t *testing.T) { + src := ` +MODEL_ID = "org/my-model" +from transformers import AutoTokenizer +tok = AutoTokenizer.from_pretrained(MODEL_ID, revision="abc123") +` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "org/my-model", disc[0].RepoID) + assert.Equal(t, "abc123", disc[0].Revision) + assert.Empty(t, unres) +} + +func TestParsePythonSource_DynamicRepoID(t *testing.T) { + src := `from transformers import AutoModel +model = AutoModel.from_pretrained(args.model_name)` + disc, unres := ParsePythonSource(src, "trainer.py", nil) + assert.Empty(t, disc) + require.Len(t, unres, 1) + assert.Contains(t, unres[0].Reason, "non-literal") + 
assert.Equal(t, "trainer.py", unres[0].Location.File) +} + +func TestParsePythonSource_FStringRepoID(t *testing.T) { + src := `from_pretrained(f"{ORG}/{name}")` + disc, unres := ParsePythonSource(src, "test.py", nil) + assert.Empty(t, disc) + require.Len(t, unres, 1) + assert.Equal(t, "f-string repo_id", unres[0].Reason) +} + +func TestParsePythonSource_DynamicRevision(t *testing.T) { + src := `snapshot_download(repo_id="org/model", revision=args.rev)` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "org/model", disc[0].RepoID) + assert.True(t, disc[0].RevisionDynamic) + assert.Equal(t, DefaultRevision, disc[0].Revision) // falls back to main + assert.Empty(t, unres) +} + +func TestParsePythonSource_HfHubDownload(t *testing.T) { + src := `hf_hub_download(repo_id="org/model", filename="model.bin", revision="v1")` + disc, unres := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "org/model", disc[0].RepoID) + assert.Equal(t, "v1", disc[0].Revision) + assert.Empty(t, unres) +} + +func TestParsePythonSource_MultipleModels(t *testing.T) { + src := ` +from transformers import AutoModel, AutoTokenizer +model = AutoModel.from_pretrained("org/model-a", revision="v1") +tok = AutoTokenizer.from_pretrained("org/model-b") +` + disc, unres := ParsePythonSource(src, "test.py", nil) + assert.Len(t, disc, 2) + assert.Empty(t, unres) +} + +func TestParsePythonSource_CommentsIgnored(t *testing.T) { + src := `# from_pretrained("should-not-match") +model = AutoModel.from_pretrained("org/real-model")` + disc, _ := ParsePythonSource(src, "test.py", nil) + require.Len(t, disc, 1) + assert.Equal(t, "org/real-model", disc[0].RepoID) +} + +func TestBuildConstTable_Reassignment(t *testing.T) { + src := ` +MODEL = "first-value" +MODEL = "second-value" +` + table := buildConstTable(src) + _, exists := table["MODEL"] + assert.False(t, exists, "reassigned name should be removed from constant table") +} + +func 
TestBuildConstTable_SingleAssignment(t *testing.T) { + src := `BASE = "org/model"` + table := buildConstTable(src) + assert.Equal(t, "org/model", table["BASE"]) +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/result.go b/sca/bom/buildinfo/technologies/huggingface/discovery/result.go new file mode 100644 index 000000000..caa658721 --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/result.go @@ -0,0 +1,57 @@ +package discovery + +// DefaultRevision is the fallback revision when none is pinned in source, +// mirroring the Hugging Face client's default behaviour. +const DefaultRevision = "main" + +// RepoType distinguishes HF model repos from dataset repos. +// The probe URL path differs: api/models/ vs api/datasets/. +type RepoType string + +const ( + RepoTypeModel RepoType = "model" + RepoTypeDataset RepoType = "dataset" +) + +// Location identifies a specific line in a source file. +// For Jupyter notebooks the File is "notebook.ipynb#cell-" and Line is +// relative to the start of that cell. +type Location struct { + File string + Line int // 1-based +} + +// DiscoveredModel is a fully or partially resolved HF reference extracted from source. +type DiscoveredModel struct { + RepoID string + // Revision is the pinned branch/tag/sha, or DefaultRevision when absent in source. + Revision string + // RevisionDefaulted is true when no revision was present in the call — + // the audit targets whatever commit the branch currently points to. + RevisionDefaulted bool + // RevisionDynamic is true when a revision arg was present but non-literal. + // The model is still audited against DefaultRevision with a warning. + RevisionDynamic bool + RepoType RepoType + // Sources lists every call site that produced this reference (after dedup). + Sources []Location +} + +// UnresolvedSite records a call site whose repo_id could not be statically resolved. 
+// These are NOT audited; they are surfaced in the warning block so the user can +// pass them explicitly via --hugging-face-model. +type UnresolvedSite struct { + Location Location + Snippet string + // Reason is one of "non-literal repo_id", "f-string repo_id", "dynamic repo_id". + Reason string +} + +// ScanResult is the output of a full directory scan. +type ScanResult struct { + // Discovered holds deduplicated (repo_type, repo_id, revision) tuples ready + // to hand to the curation walker. + Discovered []DiscoveredModel + // Unresolved holds call sites that could not be statically resolved. + Unresolved []UnresolvedSite +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/scanner.go b/sca/bom/buildinfo/technologies/huggingface/discovery/scanner.go new file mode 100644 index 000000000..c1b3c6ef2 --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/scanner.go @@ -0,0 +1,155 @@ +package discovery + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/jfrog/jfrog-client-go/utils/log" +) + +// defaultExcludeDirs are directory names skipped during the walk. +// These mirror the exclusion list used by other jf ca BOM builders. +var defaultExcludeDirs = map[string]struct{}{ + ".git": {}, + ".hg": {}, + "node_modules": {}, + "__pycache__": {}, + ".venv": {}, + "venv": {}, + "env": {}, + ".env": {}, + "site-packages": {}, + ".tox": {}, + "dist": {}, + "build": {}, + ".eggs": {}, +} + +// ScanDir walks root recursively and discovers all Hugging Face model/dataset +// references in *.py files and *.ipynb notebooks. +// It returns a deduplicated ScanResult with the warning block pre-formatted. 
+func ScanDir(root string) (*ScanResult, error) { + var allDiscovered []DiscoveredModel + var allUnresolved []UnresolvedSite + + err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + log.Debug(fmt.Sprintf("huggingface scanner: skipping %s: %v", path, err)) + return nil + } + if d.IsDir() { + if _, skip := defaultExcludeDirs[d.Name()]; skip { + return filepath.SkipDir + } + return nil + } + switch strings.ToLower(filepath.Ext(path)) { + case ".py": + disc, unres, ferr := scanPyFile(path, root) + if ferr != nil { + log.Debug(fmt.Sprintf("huggingface scanner: skipping %s: %v", path, ferr)) + return nil + } + allDiscovered = append(allDiscovered, disc...) + allUnresolved = append(allUnresolved, unres...) + case ".ipynb": + disc, unres, ferr := ParseNotebook(path) + if ferr != nil { + log.Debug(fmt.Sprintf("huggingface scanner: skipping notebook %s: %v", path, ferr)) + return nil + } + // Relativise notebook path for display. + rel, relErr := filepath.Rel(root, path) + if relErr != nil { + rel = path + } + for i := range disc { + for j := range disc[i].Sources { + disc[i].Sources[j].File = rebaseNotebook(disc[i].Sources[j].File, path, rel) + } + } + for i := range unres { + unres[i].Location.File = rebaseNotebook(unres[i].Location.File, path, rel) + } + allDiscovered = append(allDiscovered, disc...) + allUnresolved = append(allUnresolved, unres...) + } + return nil + }) + if err != nil { + return nil, fmt.Errorf("huggingface scanner: walking %s: %w", root, err) + } + + return &ScanResult{ + Discovered: dedup(allDiscovered), + Unresolved: allUnresolved, + }, nil +} + +// FormatWarnings returns the consolidated warning block for unresolved sites, +// or an empty string when there are none. 
+func FormatWarnings(unresolved []UnresolvedSite) string { + if len(unresolved) == 0 { + return "" + } + var sb strings.Builder + fmt.Fprintf(&sb, "WARN: %d Hugging Face reference(s) could not be statically resolved and were NOT audited:\n", len(unresolved)) + for _, u := range unresolved { + fmt.Fprintf(&sb, " %s:%d\t%s\t— %s\n", u.Location.File, u.Location.Line, u.Snippet, u.Reason) + } + sb.WriteString("These references resolve their model id or revision at runtime, so they cannot be checked statically.\n") + sb.WriteString("To audit them, re-run with --hugging-face-model=: (comma-separate multiple models; pin the revision you ship).") + return sb.String() +} + +// ---- internal helpers ----------------------------------------------------- + +func scanPyFile(path, root string) ([]DiscoveredModel, []UnresolvedSite, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, nil, err + } + rel, err := filepath.Rel(root, path) + if err != nil { + rel = path + } + disc, unres := ParsePythonSource(string(data), rel, nil) + return disc, unres, nil +} + +// rebaseNotebook replaces the absolute notebook path prefix with the relative one. +func rebaseNotebook(location, absPath, relPath string) string { + return strings.Replace(location, absPath, relPath, 1) +} + +// dedup collapses DiscoveredModel entries with the same (repo_type, repo_id, revision) +// into one, merging their Sources lists. +func dedup(models []DiscoveredModel) []DiscoveredModel { + type key struct { + repoType RepoType + repoID string + revision string + } + index := map[key]int{} + var result []DiscoveredModel + for _, m := range models { + k := key{m.RepoType, m.RepoID, m.Revision} + if idx, exists := index[k]; exists { + result[idx].Sources = append(result[idx].Sources, m.Sources...) + // Propagate flags: if any reference lacked a revision, flag it. 
+ if m.RevisionDefaulted { + result[idx].RevisionDefaulted = true + } + if m.RevisionDynamic { + result[idx].RevisionDynamic = true + } + } else { + index[k] = len(result) + result = append(result, m) + } + } + return result +} diff --git a/sca/bom/buildinfo/technologies/huggingface/discovery/scanner_test.go b/sca/bom/buildinfo/technologies/huggingface/discovery/scanner_test.go new file mode 100644 index 000000000..c53c065c2 --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/discovery/scanner_test.go @@ -0,0 +1,93 @@ +package discovery + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestScanDir_BasicDiscovery(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "app.py", ` +from transformers import AutoModel +model = AutoModel.from_pretrained("org/model-a", revision="v1") +`) + result, err := ScanDir(dir) + require.NoError(t, err) + require.Len(t, result.Discovered, 1) + assert.Equal(t, "org/model-a", result.Discovered[0].RepoID) + assert.Empty(t, result.Unresolved) +} + +func TestScanDir_Dedup(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "a.py", `from_pretrained("org/model", revision="v1")`) + writeFile(t, dir, "b.py", `from_pretrained("org/model", revision="v1")`) + result, err := ScanDir(dir) + require.NoError(t, err) + assert.Len(t, result.Discovered, 1, "same model in two files should deduplicate") + assert.Len(t, result.Discovered[0].Sources, 2, "both source locations should be recorded") +} + +func TestScanDir_ExcludeVenv(t *testing.T) { + dir := t.TempDir() + venv := filepath.Join(dir, ".venv") + require.NoError(t, os.MkdirAll(venv, 0755)) + writeFile(t, venv, "excluded.py", `from_pretrained("org/should-not-appear")`) + writeFile(t, dir, "real.py", `from_pretrained("org/real-model")`) + result, err := ScanDir(dir) + require.NoError(t, err) + require.Len(t, result.Discovered, 1) + assert.Equal(t, "org/real-model", 
result.Discovered[0].RepoID) +} + +func TestScanDir_UnresolvedWarning(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "dynamic.py", `from_pretrained(args.model_name)`) + result, err := ScanDir(dir) + require.NoError(t, err) + assert.Empty(t, result.Discovered) + require.Len(t, result.Unresolved, 1) + + warn := FormatWarnings(result.Unresolved) + assert.Contains(t, warn, "WARN:") + assert.Contains(t, warn, "--hugging-face-model") +} + +func TestScanDir_MixedResolved(t *testing.T) { + dir := t.TempDir() + writeFile(t, dir, "mixed.py", ` +from_pretrained("org/good-model", revision="v1") +from_pretrained(config.model_id) +snapshot_download(repo_id="org/dataset", repo_type="dataset") +`) + result, err := ScanDir(dir) + require.NoError(t, err) + assert.Len(t, result.Discovered, 2) + assert.Len(t, result.Unresolved, 1) +} + +func TestFormatWarnings_Empty(t *testing.T) { + assert.Equal(t, "", FormatWarnings(nil)) + assert.Equal(t, "", FormatWarnings([]UnresolvedSite{})) +} + +func TestFormatWarnings_Content(t *testing.T) { + sites := []UnresolvedSite{ + {Location: Location{File: "trainer.py", Line: 42}, Snippet: "from_pretrained(args.m)", Reason: "non-literal repo_id"}, + } + out := FormatWarnings(sites) + assert.True(t, strings.HasPrefix(out, "WARN:")) + assert.Contains(t, out, "trainer.py:42") + assert.Contains(t, out, "non-literal repo_id") + assert.Contains(t, out, "--hugging-face-model") +} + +func writeFile(t *testing.T, dir, name, content string) { + t.Helper() + require.NoError(t, os.WriteFile(filepath.Join(dir, name), []byte(content), 0644)) +} diff --git a/sca/bom/buildinfo/technologies/huggingface/huggingface.go b/sca/bom/buildinfo/technologies/huggingface/huggingface.go new file mode 100644 index 000000000..33b142deb --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/huggingface.go @@ -0,0 +1,229 @@ +package huggingface + +import ( + "fmt" + "os" + "strings" + + "github.com/jfrog/jfrog-cli-core/v2/common/project" + 
"github.com/jfrog/jfrog-cli-core/v2/utils/config" + "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies" + "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies/huggingface/discovery" + "github.com/jfrog/jfrog-cli-security/utils/artifactory" + "github.com/jfrog/jfrog-client-go/utils/log" + xrayUtils "github.com/jfrog/jfrog-client-go/xray/services/utils" +) + +// HuggingFacePackagePrefix is the node-id prefix used for Hugging Face model/dataset references. +const HuggingFacePackagePrefix = "huggingfaceml://" + +// DatasetNodeMarker is inserted immediately after HuggingFacePackagePrefix in a node id +// to mark the reference as a dataset. Datasets are probed via the api/datasets/ endpoint +// rather than api/models/. The '|' separator cannot appear in a HF repo id or revision, +// so it round-trips cleanly through getHuggingFaceNameAndVersion. +const DatasetNodeMarker = "dataset|" + +// DefaultRevision is used when the model reference does not pin an explicit revision. +const DefaultRevision = "main" + +// hfEndpointEnv is the env var the Hugging Face client uses to point at the Artifactory +// proxy, e.g. "https://my.jfrog.io/artifactory/api/huggingfaceml/my-hugging-face-repo". +const hfEndpointEnv = "HF_ENDPOINT" + +// hfEndpointRepoMarker precedes the Artifactory repository name in HF_ENDPOINT. +const hfEndpointRepoMarker = "api/huggingfaceml/" + +// ModelInfo holds the parsed components of a --hugging-face-model reference. +// +// The flag value is the Hugging Face model/dataset id with an optional revision: +// "/[:]" e.g. "mcpotato/42-eicar-street:main". +// - RepoId = "mcpotato/42-eicar-street" (the Hugging Face model/dataset id) +// - Revision = "main" (branch, tag, or 40-char commit sha; defaults to "main") +// +// The Artifactory repository is NOT part of the flag — it is read from the HF_ENDPOINT +// environment variable (the same one the HF client uses to resolve through the proxy). 
+type ModelInfo struct { + RepoId string + Revision string +} + +// ParseModelReference parses a --hugging-face-model value ("/[:]") +// into the model id and revision. The revision defaults to "main" when not pinned, +// mirroring the Hugging Face client's snapshot_download(..., revision="main"). +func ParseModelReference(modelRef string) (*ModelInfo, error) { + modelRef = strings.TrimSpace(strings.TrimPrefix(modelRef, HuggingFacePackagePrefix)) + if modelRef == "" { + return nil, fmt.Errorf("hugging face model reference is empty") + } + + info := &ModelInfo{Revision: DefaultRevision} + + // Split off the optional revision. A revision never contains '/', so only treat a + // trailing ":" as a revision when the suffix has no path separator. + if idx := strings.LastIndex(modelRef, ":"); idx > 0 && !strings.Contains(modelRef[idx+1:], "/") { + info.Revision = modelRef[idx+1:] + modelRef = modelRef[:idx] + } + + if modelRef == "" { + return nil, fmt.Errorf("invalid hugging face model reference: expected '/[:]'") + } + info.RepoId = modelRef + + log.Debug(fmt.Sprintf("Parsed Hugging Face model - RepoId: %s, Revision: %s", info.RepoId, info.Revision)) + return info, nil +} + +// ParseModelReferences parses a comma-separated list of --hugging-face-model values +// (":,:,...") into individual ModelInfo entries. +// Whitespace around each entry is trimmed and empty entries are skipped, so trailing +// commas and accidental spaces are tolerated. +func ParseModelReferences(modelRefs string) ([]*ModelInfo, error) { + var infos []*ModelInfo + for _, ref := range strings.Split(modelRefs, ",") { + ref = strings.TrimSpace(ref) + if ref == "" { + continue + } + info, err := ParseModelReference(ref) + if err != nil { + return nil, err + } + infos = append(infos, info) + } + if len(infos) == 0 { + return nil, fmt.Errorf("hugging face model reference is empty") + } + return infos, nil +} + +// repoFromHFEndpoint extracts the Artifactory repository name from the HF_ENDPOINT env var. 
+// HF_ENDPOINT looks like ".../artifactory/api/huggingfaceml/"; we return "". +func repoFromHFEndpoint() (string, error) { + endpoint := strings.TrimSpace(os.Getenv(hfEndpointEnv)) + if endpoint == "" { + return "", fmt.Errorf("%s is not set. Export it to your Artifactory Hugging Face repository, e.g. '%s=https:///artifactory/%s'", + hfEndpointEnv, hfEndpointEnv, hfEndpointRepoMarker) + } + idx := strings.Index(endpoint, hfEndpointRepoMarker) + if idx < 0 { + return "", fmt.Errorf("%s ('%s') does not contain '%s'; cannot determine the Artifactory repository", hfEndpointEnv, endpoint, hfEndpointRepoMarker) + } + // The repository is the first path segment after the marker (ignore any trailing path/query). + repo := strings.Trim(endpoint[idx+len(hfEndpointRepoMarker):], "/") + if before, _, found := strings.Cut(repo, "/"); found { + repo = before + } + if before, _, found := strings.Cut(repo, "?"); found { + repo = before + } + if repo == "" { + return "", fmt.Errorf("%s ('%s') has no repository segment after '%s'", hfEndpointEnv, endpoint, hfEndpointRepoMarker) + } + return repo, nil +} + +// BuildDependencyTree builds the dependency graph for Hugging Face models/datasets. +// +// Two modes: +// 1. Flag mode (--hugging-face-model set): pure spot-check — audits only the +// explicitly named models (comma-separated), skips source scanning entirely. +// Fast and unambiguous: you get exactly what you asked for. +// 2. Auto-discovery mode (no flag): scans Python source and notebooks in the +// working directory for from_pretrained / snapshot_download / load_dataset / +// hf_hub_download call sites. Unresolved (dynamic) call sites are returned as +// warnings for the caller to surface after the curation tables. +func BuildDependencyTree(params technologies.BuildInfoBomGeneratorParams) (trees []*xrayUtils.GraphNode, uniqueIDs []string, warnings []string, err error) { + workingDir := params.WorkingDirectory + if workingDir == "" { + workingDir = "." 
+ } + + var children []*xrayUtils.GraphNode + seen := map[string]bool{} + add := func(nodeID string) { + if seen[nodeID] { + return + } + seen[nodeID] = true + children = append(children, &xrayUtils.GraphNode{Id: nodeID}) + uniqueIDs = append(uniqueIDs, nodeID) + } + + // 1) Explicit models from the --hugging-face-model flag (comma-separated). + // Flag mode is a pure spot-check: only the named models are audited, the source + // scanner is skipped entirely so the result is fast and unambiguous. + if params.HuggingFaceModel != "" { + models, perr := ParseModelReferences(params.HuggingFaceModel) + if perr != nil { + return nil, nil, nil, perr + } + for _, m := range models { + add(HuggingFacePackagePrefix + m.RepoId + ":" + m.Revision) + } + if len(children) == 0 { + return nil, nil, nil, nil + } + root := &xrayUtils.GraphNode{Id: "huggingface-project", Nodes: children} + return []*xrayUtils.GraphNode{root}, uniqueIDs, nil, nil + } + + // 2) Auto-discovery mode: scan Python source / notebooks in the working dir. 
+ log.Debug(fmt.Sprintf("Hugging Face: scanning %s for model references", workingDir)) + result, serr := discovery.ScanDir(workingDir) + if serr != nil { + return nil, nil, nil, fmt.Errorf("hugging face source scan failed: %w", serr) + } + if warn := discovery.FormatWarnings(result.Unresolved); warn != "" { + warnings = append(warnings, warn) + } + for _, m := range result.Discovered { + nodeID := HuggingFacePackagePrefix + m.RepoID + ":" + m.Revision + if m.RepoType == discovery.RepoTypeDataset { + nodeID = HuggingFacePackagePrefix + DatasetNodeMarker + m.RepoID + ":" + m.Revision + } + if m.RevisionDefaulted { + log.Info(fmt.Sprintf("Hugging Face: %s has no pinned revision — auditing against current HEAD of '%s'", m.RepoID, m.Revision)) + } + if m.RevisionDynamic { + warnings = append(warnings, fmt.Sprintf("Hugging Face: %s has a dynamic revision in source — audited against '%s' (may not match the revision resolved at runtime)", m.RepoID, m.Revision)) + } + add(nodeID) + } + + if len(children) == 0 { + log.Debug("Hugging Face: no model references found (flag or source)") + return nil, nil, warnings, nil + } + + // Flat graph: one root node whose children are the unique model/dataset refs. + root := &xrayUtils.GraphNode{Id: "huggingface-project", Nodes: children} + return []*xrayUtils.GraphNode{root}, uniqueIDs, warnings, nil +} + +// GetHuggingFaceRepositoryConfig resolves the Artifactory repository from HF_ENDPOINT +// and verifies it exists, mirroring docker.GetDockerRepositoryConfig. +func GetHuggingFaceRepositoryConfig() (*project.RepositoryConfig, error) { + serverDetails, err := config.GetDefaultServerConf() + if err != nil { + return nil, err + } + if serverDetails == nil { + return nil, fmt.Errorf("no Artifactory server configured. 
Use 'jf c add' to configure a server") + } + repo, err := repoFromHFEndpoint() + if err != nil { + return nil, err + } + exists, err := artifactory.IsRepoExists(repo, serverDetails) + if err != nil { + return nil, fmt.Errorf("failed to check if repository '%s' exists on Artifactory '%s': %w", repo, serverDetails.Url, err) + } + if !exists { + return nil, fmt.Errorf("repository '%s' (from %s) was not found on Artifactory (%s), ensure the repository exists", repo, hfEndpointEnv, serverDetails.Url) + } + + repoConfig := &project.RepositoryConfig{} + repoConfig.SetServerDetails(serverDetails).SetTargetRepo(repo) + return repoConfig, nil +} diff --git a/sca/bom/buildinfo/technologies/huggingface/huggingface_test.go b/sca/bom/buildinfo/technologies/huggingface/huggingface_test.go new file mode 100644 index 000000000..ece00510e --- /dev/null +++ b/sca/bom/buildinfo/technologies/huggingface/huggingface_test.go @@ -0,0 +1,243 @@ +package huggingface + +import ( + "os" + "testing" + + "github.com/jfrog/jfrog-cli-security/sca/bom/buildinfo/technologies" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseModelReference(t *testing.T) { + tests := []struct { + name string + input string + wantRepoId string + wantRevision string + wantErr string + }{ + { + name: "model id with sha revision", + input: "mcpotato/42-eicar-street:8fb61c4d511e9aaff0ea55396a124aa292830efc", + wantRepoId: "mcpotato/42-eicar-street", + wantRevision: "8fb61c4d511e9aaff0ea55396a124aa292830efc", + }, + { + name: "model id with branch revision", + input: "mcpotato/42-eicar-street:main", + wantRepoId: "mcpotato/42-eicar-street", + wantRevision: "main", + }, + { + name: "no revision defaults to main", + input: "org/model", + wantRepoId: "org/model", + wantRevision: DefaultRevision, + }, + { + name: "single-segment model id with revision", + input: "bert-base-uncased:v1.0", + wantRepoId: "bert-base-uncased", + wantRevision: "v1.0", + }, + { + name: 
"huggingfaceml:// prefix stripped", + input: HuggingFacePackagePrefix + "org/model:v2", + wantRepoId: "org/model", + wantRevision: "v2", + }, + { + name: "empty input", + input: "", + wantErr: "hugging face model reference is empty", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info, err := ParseModelReference(tt.input) + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantRepoId, info.RepoId, "RepoId mismatch") + assert.Equal(t, tt.wantRevision, info.Revision, "Revision mismatch") + }) + } +} + +func TestParseModelReferences(t *testing.T) { + t.Run("comma-separated with whitespace and trailing comma", func(t *testing.T) { + infos, err := ParseModelReferences(" org/a:main , org/b:v2 ,") + require.NoError(t, err) + require.Len(t, infos, 2) + assert.Equal(t, "org/a", infos[0].RepoId) + assert.Equal(t, "main", infos[0].Revision) + assert.Equal(t, "org/b", infos[1].RepoId) + assert.Equal(t, "v2", infos[1].Revision) + }) + t.Run("single value", func(t *testing.T) { + infos, err := ParseModelReferences("org/only") + require.NoError(t, err) + require.Len(t, infos, 1) + assert.Equal(t, "org/only", infos[0].RepoId) + assert.Equal(t, DefaultRevision, infos[0].Revision) + }) + t.Run("all empty entries error", func(t *testing.T) { + _, err := ParseModelReferences(" , , ") + require.Error(t, err) + assert.Contains(t, err.Error(), "empty") + }) +} + +func TestRepoFromHFEndpoint(t *testing.T) { + tests := []struct { + name string + endpoint string + wantRepo string + wantErr string + }{ + { + name: "standard endpoint", + endpoint: "https://z0gytst.jfrogdev.org/artifactory/api/huggingfaceml/my-hugging-face-repo", + wantRepo: "my-hugging-face-repo", + }, + { + name: "endpoint with trailing slash", + endpoint: "https://my.jfrog.io/artifactory/api/huggingfaceml/hf-repo/", + wantRepo: "hf-repo", + }, + { + name: "endpoint with extra path after repo", 
+ endpoint: "https://my.jfrog.io/artifactory/api/huggingfaceml/hf-repo/api/models", + wantRepo: "hf-repo", + }, + { + name: "not set", + endpoint: "", + wantErr: "HF_ENDPOINT is not set", + }, + { + name: "missing marker", + endpoint: "https://my.jfrog.io/artifactory/api/npm/npm-repo", + wantErr: "does not contain", + }, + { + name: "no repo segment after marker", + endpoint: "https://my.jfrog.io/artifactory/api/huggingfaceml/", + wantErr: "no repository segment", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv(hfEndpointEnv, tt.endpoint) + repo, err := repoFromHFEndpoint() + if tt.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantRepo, repo) + }) + } +} + +func TestBuildDependencyTree(t *testing.T) { + tests := []struct { + name string + modelRef string + wantLeafId string + wantUniqueDep string + }{ + { + name: "blocked malicious model with sha", + modelRef: "mcpotato/42-eicar-street:8fb61c4d511e9aaff0ea55396a124aa292830efc", + wantLeafId: HuggingFacePackagePrefix + "mcpotato/42-eicar-street:8fb61c4d511e9aaff0ea55396a124aa292830efc", + wantUniqueDep: HuggingFacePackagePrefix + "mcpotato/42-eicar-street:8fb61c4d511e9aaff0ea55396a124aa292830efc", + }, + { + name: "model with branch revision", + modelRef: "bert-base-uncased:main", + wantLeafId: HuggingFacePackagePrefix + "bert-base-uncased:main", + wantUniqueDep: HuggingFacePackagePrefix + "bert-base-uncased:main", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Empty working dir so only the flag contributes (no stray discovery). 
+ params := technologies.BuildInfoBomGeneratorParams{HuggingFaceModel: tt.modelRef, WorkingDirectory: t.TempDir()} + trees, uniqueDeps, warnings, err := BuildDependencyTree(params) + require.NoError(t, err) + require.Len(t, trees, 1, "expected exactly one dependency tree") + assert.Equal(t, "huggingface-project", trees[0].Id, "root node id mismatch") + require.Len(t, trees[0].Nodes, 1, "expected exactly one child node") + assert.Equal(t, tt.wantLeafId, trees[0].Nodes[0].Id, "leaf node id mismatch") + require.Len(t, uniqueDeps, 1, "expected exactly one unique dep") + assert.Equal(t, tt.wantUniqueDep, uniqueDeps[0], "unique dep mismatch") + assert.Empty(t, warnings, "flag-only mode should not produce warnings") + }) + } +} + +// TestBuildDependencyTree_MultiValueAndAdditive verifies that the flag accepts a +// comma-separated list and that flag mode audits exactly the named models — no more, +// no less (source scan is skipped in flag mode). +func TestBuildDependencyTree_MultiValueAndAdditive(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile( + dir+"/app.py", []byte(`from_pretrained("org/discovered-model", revision="v1")`), 0644)) + + params := technologies.BuildInfoBomGeneratorParams{ + // Two explicit flag models; source file has a third model that must NOT appear. + HuggingFaceModel: "org/flag-model-a:main, org/flag-model-b:v2", + WorkingDirectory: dir, + } + trees, uniqueDeps, warnings, err := BuildDependencyTree(params) + require.NoError(t, err) + require.Len(t, trees, 1) + // Only the two flag models — discovered-model from source must be absent. 
+ assert.Len(t, uniqueDeps, 2) + assert.Contains(t, uniqueDeps, HuggingFacePackagePrefix+"org/flag-model-a:main") + assert.Contains(t, uniqueDeps, HuggingFacePackagePrefix+"org/flag-model-b:v2") + assert.NotContains(t, uniqueDeps, HuggingFacePackagePrefix+"org/discovered-model:v1") + assert.Empty(t, warnings, "flag mode skips source scan so no warnings expected") +} + +// TestBuildDependencyTree_AutoDiscovery verifies that an empty HuggingFaceModel flag +// triggers source scanning on the working directory instead of returning an error. +func TestBuildDependencyTree_AutoDiscovery(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile( + dir+"/app.py", []byte(`from_pretrained("org/discovered-model", revision="v1")`), 0644)) + + params := technologies.BuildInfoBomGeneratorParams{ + HuggingFaceModel: "", + WorkingDirectory: dir, + } + trees, uniqueDeps, _, err := BuildDependencyTree(params) + require.NoError(t, err) + require.Len(t, trees, 1) + require.Len(t, uniqueDeps, 1) + assert.Contains(t, uniqueDeps[0], "org/discovered-model") +} + +// TestBuildDependencyTree_UnresolvedWarnings verifies that dynamic references are +// returned as warnings (for the caller to surface after the curation table) rather +// than being logged during the BOM-build phase. 
+func TestBuildDependencyTree_UnresolvedWarnings(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(dir+"/dyn.py", []byte( + "runtime = AutoModel.from_pretrained(args.model_name)\n"), 0644)) + + params := technologies.BuildInfoBomGeneratorParams{WorkingDirectory: dir} + trees, uniqueDeps, warnings, err := BuildDependencyTree(params) + require.NoError(t, err) + assert.Empty(t, trees, "no statically-resolvable models expected") + assert.Empty(t, uniqueDeps) + require.Len(t, warnings, 1, "expected one consolidated unresolved-references warning") + assert.Contains(t, warnings[0], "could not be statically resolved") + assert.Contains(t, warnings[0], "non-literal repo_id") +} diff --git a/utils/techutils/techutils.go b/utils/techutils/techutils.go index 66a553557..142a1b472 100644 --- a/utils/techutils/techutils.go +++ b/utils/techutils/techutils.go @@ -59,14 +59,15 @@ const ( Swift Technology = "swift" Gem Technology = "ruby" // Not Supported by build-info BOM generator - Docker Technology = "docker" - Oci Technology = "oci" - Rpm Technology = "rpm" - Debian Technology = "deb" - Composer Technology = "composer" - Alpine Technology = "alpine" - Cpp Technology = "cpp" - NoTech Technology = "" + Docker Technology = "docker" + HuggingFaceMl Technology = "huggingfaceml" + Oci Technology = "oci" + Rpm Technology = "rpm" + Debian Technology = "deb" + Composer Technology = "composer" + Alpine Technology = "alpine" + Cpp Technology = "cpp" + NoTech Technology = "" ) // Alternative package types for some technologies @@ -86,6 +87,7 @@ var AllTechnologiesStrings = []string{ Nuget.String(), Dotnet.String(), Docker.String(), + HuggingFaceMl.String(), Oci.String(), Conan.String(), Cocoapods.String(), @@ -299,6 +301,10 @@ var technologiesData = map[Technology]TechData{ formal: "Docker", projectType: project.Docker, }, + HuggingFaceMl: { + formal: "Hugging Face", + xrayPackageType: "huggingfaceml", + }, Oci: {}, Rpm: {formal: "RPM"}, Debian: {formal: "Debian"},