// Copyright 2020-2025 Buf Technologies, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
	"bytes"
	"context"
	"encoding/hex"
	"fmt"
	"go/format"
	"go/token"
	"io"
	"sort"
	"strings"

	"buf.build/go/app/appcmd"
	"buf.build/go/app/appext"
	"github.com/bufbuild/buf/private/pkg/shake256"
	"github.com/bufbuild/buf/private/pkg/slogapp"
	"github.com/spf13/pflag"
)

const (
	// programName is the command's name. It is used as the cobra Use string,
	// as the appext builder name, and in the "Code generated by" header of
	// the emitted Go file.
	programName = "buf-legacyfederation-go-data"
	// pkgFlagName is the name of the required flag that selects the package
	// name written into the generated file.
	pkgFlagName = "package"
)

func main() {
	appcmd.Main(context.Background(), newCommand())
}

// newCommand constructs the program's single command.
//
// The command accepts no positional arguments. It binds the --package flag
// and, when executed, calls run with the appext container produced by a
// builder configured with the slogapp logger provider.
func newCommand() *appcmd.Command {
	cmdFlags := newFlags()
	runFunc := appext.NewBuilder(
		programName,
		appext.BuilderWithLoggerProvider(slogapp.LoggerProvider),
	).NewRunFunc(
		func(ctx context.Context, container appext.Container) error {
			return run(ctx, container, cmdFlags)
		},
	)
	return &appcmd.Command{
		Use:       programName,
		Short:     "Send a newline-separated file via stdin",
		Args:      appcmd.NoArgs,
		Run:       runFunc,
		BindFlags: cmdFlags.Bind,
	}
}

// flags holds the values of this program's command-line flags.
type flags struct {
	// Pkg is the package name to write into the generated file's package
	// clause (bound to --package).
	Pkg string
}

// newFlags allocates and returns a new, zero-valued flags.
func newFlags() *flags {
	f := new(flags)
	return f
}

// Bind registers the --package string flag on flagSet, storing its value in
// f.Pkg. The flag defaults to the empty string, which run rejects as missing.
func (f *flags) Bind(flagSet *pflag.FlagSet) {
	const defaultPkg = ""
	flagSet.StringVar(&f.Pkg, pkgFlagName, defaultPkg, "The name of the generated package.")
}

// run reads newline-separated entries from stdin and hashes each unique,
// non-blank entry with shake256. It writes a gofmt-formatted Go file to
// stdout containing the sorted hex-encoded digests.
//
// The --package flag is required and must be a valid Go package name. The
// value is spliced verbatim into the generated package clause, so anything
// else would either fail formatting with an opaque error or inject extra
// code into the output. Violations are reported as invalid-argument errors.
// I/O and hashing failures are wrapped with the step that failed.
func run(ctx context.Context, container appext.Container, flags *flags) error {
	if flags.Pkg == "" {
		return appcmd.NewInvalidArgumentErrorf("--%s is required", pkgFlagName)
	}
	// token.IsIdentifier rejects keywords and anything with whitespace or
	// punctuation. The blank identifier is syntactically an identifier but
	// is not a legal package name.
	if !token.IsIdentifier(flags.Pkg) || flags.Pkg == "_" {
		return appcmd.NewInvalidArgumentErrorf("--%s must be a valid Go package name, got %q", pkgFlagName, flags.Pkg)
	}
	data, err := io.ReadAll(container.Stdin())
	if err != nil {
		return fmt.Errorf("read stdin: %w", err)
	}
	hexEncodedDigests, err := getSortedHexEncodedDigests(data)
	if err != nil {
		return fmt.Errorf("compute hostname digests: %w", err)
	}
	golangFileData, err := getGolangFileData(flags.Pkg, hexEncodedDigests)
	if err != nil {
		return fmt.Errorf("generate Go file: %w", err)
	}
	if _, err := container.Stdout().Write(golangFileData); err != nil {
		return fmt.Errorf("write stdout: %w", err)
	}
	return nil
}

// getSortedHexEncodedDigests splits data on "\n" and trims surrounding
// whitespace from each entry. It drops blank and duplicate entries.
//
// It returns the hex encoding of each remaining entry's shake256 digest,
// sorted lexicographically. The result is non-nil even when there are no
// entries. The first hashing error aborts and is returned as-is.
func getSortedHexEncodedDigests(data []byte) ([]string, error) {
	seen := make(map[string]struct{})
	var uniqueEntries []string
	for _, rawEntry := range strings.Split(string(data), "\n") {
		entry := strings.TrimSpace(rawEntry)
		if entry == "" {
			continue
		}
		if _, duplicate := seen[entry]; duplicate {
			continue
		}
		seen[entry] = struct{}{}
		uniqueEntries = append(uniqueEntries, entry)
	}
	digests := make([]string, len(uniqueEntries))
	for i, entry := range uniqueEntries {
		digest, err := shake256.NewDigestForContent(strings.NewReader(entry))
		if err != nil {
			return nil, err
		}
		digests[i] = hex.EncodeToString(digest.Value())
	}
	sort.Strings(digests)
	return digests, nil
}

// getGolangFileData renders Go source for package packageName and returns it
// after running it through go/format. The generated file declares:
//   - hostnameHexEncodedDigests, a set keyed by hexEncodedDigests;
//   - an exported Exists function that hex-encodes the shake256 digest of a
//     hostname and reports whether it is in that set. An empty hostname
//     always yields false.
//
// If formatting fails (e.g. packageName is not a valid identifier), the
// returned error wraps the format error and includes the unformatted source.
func getGolangFileData(
	packageName string,
	hexEncodedDigests []string,
) ([]byte, error) {
	var src bytes.Buffer
	emit := func(parts ...string) {
		for _, part := range parts {
			_, _ = src.WriteString(part)
		}
	}

	// Generated-code header and package clause.
	emit("// Code generated by ", programName, ". DO NOT EDIT.", "\n\n")
	emit("package ", packageName, "\n\n")

	// Imports required by the generated Exists function.
	emit(`import (
	"encoding/hex"
	"strings"

	"github.com/bufbuild/buf/private/pkg/shake256"
)`, "\n\n")

	// The digest set, one entry per line. Indentation is left to
	// format.Source below.
	emit("var (", "\n")
	emit(`// hostnameHexEncodedDigests are the shake256 digests of the hostnames that are allowed to use legacy federation.
	hostnameHexEncodedDigests = map[string]struct{}{`, "\n")
	for _, digest := range hexEncodedDigests {
		emit(`"`, digest, `": {},`, "\n")
	}
	emit("}", ")", "\n\n")

	// The lookup function exported by the generated package.
	emit(`// Exists returns true if the hostname is allowed to use legacy federation.
func Exists(hostname string) (bool, error) {
	if hostname == "" {
		return false, nil
	}
	digest, err := shake256.NewDigestForContent(strings.NewReader(hostname))
	if err != nil {
		return false, err
	}
	hexEncodedDigest := hex.EncodeToString(digest.Value())
	_, ok := hostnameHexEncodedDigests[hexEncodedDigest]
	return ok, nil
}`)

	formatted, err := format.Source(src.Bytes())
	if err != nil {
		return nil, fmt.Errorf("could not format: %w\n%s", err, src.String())
	}
	return formatted, nil
}
