package osvmatcher

import (
	"context"
	"errors"
	"reflect"
	"testing"
	"time"

	"github.com/google/osv-scalibr/extractor"
	"github.com/google/osv-scalibr/purl"
	"github.com/ossf/osv-schema/bindings/go/osvschema"
	"osv.dev/bindings/go/osvdev"
)

// TestOSVMatcher_MatchVulnerabilities checks that OSVMatcher.MatchVulnerabilities
// reports an error matching context.DeadlineExceeded (via errors.Is) and returns
// no results when InitialQueryTimeout expires. One case uses a timeout intended to
// expire inside the HTTP client code. The other uses a timeout intended to expire
// before the HTTP client is reached.
func TestOSVMatcher_MatchVulnerabilities(t *testing.T) {
	t.Parallel()

	// fields mirrors the OSVMatcher fields configured by each case.
	type fields struct {
		Client              osvdev.OSVClient
		InitialQueryTimeout time.Duration
	}

	// args holds the inputs passed to MatchVulnerabilities.
	type args struct {
		pkgs []*extractor.Package
	}

	tests := []struct {
		name    string
		fields  fields
		args    args
		want    [][]*osvschema.Vulnerability
		wantErr error
	}{
		{
			name: "Timeout returns deadline exceeded error (http.Client code)",
			fields: fields{
				Client: *osvdev.DefaultClient(),
				// Long enough that the deadline is not hit until execution has
				// entered the http client code.
				InitialQueryTimeout: 1 * time.Millisecond,
			},
			args: args{
				pkgs: []*extractor.Package{
					{
						Name:     "stdlib",
						Version:  "1.22.0",
						PURLType: purl.TypeGolang,
					},
				},
			},
			want:    nil,
			wantErr: context.DeadlineExceeded,
		},
		{
			name: "Timeout returns deadline exceeded error (osv.dev code)",
			fields: fields{
				Client: *osvdev.DefaultClient(),
				// Short enough that the deadline expires before the http client
				// is reached.
				InitialQueryTimeout: 100 * time.Nanosecond,
			},
			args: args{
				pkgs: []*extractor.Package{
					{
						Name:     "stdlib",
						Version:  "1.22.0",
						PURLType: purl.TypeGolang,
					},
				},
			},
			want:    nil,
			wantErr: context.DeadlineExceeded,
		},
	}

	// Since Go 1.22 each iteration has its own tt, so the parallel subtests can
	// capture it directly without an explicit per-iteration copy.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			matcher := &OSVMatcher{
				Client:              tt.fields.Client,
				InitialQueryTimeout: tt.fields.InitialQueryTimeout,
			}

			got, err := matcher.MatchVulnerabilities(t.Context(), tt.args.pkgs)
			// errors.Is also matches a wrapped deadline error, such as one
			// wrapped by the HTTP client.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("OSVMatcher.MatchVulnerabilities() error = %v, wantErr %v", err, tt.wantErr)
			}

			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("OSVMatcher.MatchVulnerabilities() = %v, want %v", got, tt.want)
			}
		})
	}
}
