Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions backend/go/cloud-proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/mudler/xlog"

"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/httpclient"
)
Expand Down Expand Up @@ -145,7 +146,7 @@ func resolveAPIKey(envName, filePath string) (string, error) {
func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err error) {
cfg := c.cfg.Load()
if cfg == nil {
return nil, errors.New("cloud-proxy: model not loaded")
return nil, grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modeTranslate {
return nil, fmt.Errorf("cloud-proxy: Predict only valid in translate mode (have %s)", cfg.mode)
Expand Down Expand Up @@ -175,7 +176,7 @@ func (c *CloudProxy) PredictRich(opts *pb.PredictOptions) (reply *pb.Reply, err
func (c *CloudProxy) PredictStreamRich(opts *pb.PredictOptions, results chan<- *pb.Reply) (err error) {
cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
return grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modeTranslate {
return fmt.Errorf("cloud-proxy: PredictStream only valid in translate mode (have %s)", cfg.mode)
Expand Down Expand Up @@ -269,7 +270,7 @@ func (c *CloudProxy) Forward(ctx context.Context, in <-chan *pb.ForwardRequest,

cfg := c.cfg.Load()
if cfg == nil {
return errors.New("cloud-proxy: model not loaded")
return grpcerrors.ModelNotLoaded("cloud-proxy")
}
if cfg.mode != modePassthrough {
return fmt.Errorf("cloud-proxy: Forward only valid in passthrough mode (have %s)", cfg.mode)
Expand Down
5 changes: 3 additions & 2 deletions backend/go/parakeet-cpp/goparakeetcpp.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"github.com/go-audio/wav"
"github.com/mudler/LocalAI/pkg/grpc/base"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/LocalAI/pkg/utils"
"github.com/mudler/xlog"
Expand Down Expand Up @@ -230,7 +231,7 @@ func (p *ParakeetCpp) runBatch(reqs []*batchRequest) {
// (L2).
func (p *ParakeetCpp) AudioTranscription(ctx context.Context, opts *pb.TranscriptRequest) (pb.TranscriptResult, error) {
if p.ctxPtr == 0 {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: model not loaded")
return pb.TranscriptResult{}, grpcerrors.ModelNotLoaded("parakeet-cpp")
}
if opts.Dst == "" {
return pb.TranscriptResult{}, errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
Expand Down Expand Up @@ -351,7 +352,7 @@ func (p *ParakeetCpp) AudioTranscriptionStream(ctx context.Context, opts *pb.Tra
defer close(results)

if p.ctxPtr == 0 {
return errors.New("parakeet-cpp: model not loaded")
return grpcerrors.ModelNotLoaded("parakeet-cpp")
}
if opts.Dst == "" {
return errors.New("parakeet-cpp: TranscriptRequest.dst (audio path) is required")
Expand Down
56 changes: 44 additions & 12 deletions core/services/nodes/inflight.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"time"

"github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
pb "github.com/mudler/LocalAI/pkg/grpc/proto"
"github.com/mudler/xlog"
ggrpc "google.golang.org/grpc"
Expand Down Expand Up @@ -64,64 +65,95 @@ func (c *InFlightTrackingClient) track(ctx context.Context) func() {
}
}

// reconcile self-heals stale routing: when a backend reports that the model is
// no longer loaded (the process survived but the model was evicted, while the
// registry still lists it as loaded), it drops the replica row so the next
// request triggers a fresh load instead of routing back here. Without this the
// model stays unreachable until the controller restarts. The original error is
// returned unchanged.
func (c *InFlightTrackingClient) reconcile(err error) error {
if !grpcerrors.IsModelNotLoaded(err) {
return err
}
rmCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if rmErr := c.registry.RemoveNodeModel(rmCtx, c.nodeID, c.modelName, c.replicaIndex); rmErr != nil {
xlog.Warn("Failed to drop stale replica after model-not-loaded",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex, "error", rmErr)
} else {
xlog.Warn("Backend reports model not loaded; dropped stale replica so the next request reloads",
"node", c.nodeID, "model", c.modelName, "replica", c.replicaIndex)
}
return err
}

// --- Tracked inference methods ---

func (c *InFlightTrackingClient) Predict(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.Reply, error) {
defer c.track(ctx)()
return c.Backend.Predict(ctx, in, opts...)
reply, err := c.Backend.Predict(ctx, in, opts...)
return reply, c.reconcile(err)
}

func (c *InFlightTrackingClient) PredictStream(ctx context.Context, in *pb.PredictOptions, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.PredictStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.PredictStream(ctx, in, f, opts...))
}

func (c *InFlightTrackingClient) Embeddings(ctx context.Context, in *pb.PredictOptions, opts ...ggrpc.CallOption) (*pb.EmbeddingResult, error) {
defer c.track(ctx)()
return c.Backend.Embeddings(ctx, in, opts...)
res, err := c.Backend.Embeddings(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateImage(ctx, in, opts...)
res, err := c.Backend.GenerateImage(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) GenerateVideo(ctx context.Context, in *pb.GenerateVideoRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.GenerateVideo(ctx, in, opts...)
res, err := c.Backend.GenerateVideo(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) TTS(ctx context.Context, in *pb.TTSRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.TTS(ctx, in, opts...)
res, err := c.Backend.TTS(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) TTSStream(ctx context.Context, in *pb.TTSRequest, f func(reply *pb.Reply), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.TTSStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.TTSStream(ctx, in, f, opts...))
}

func (c *InFlightTrackingClient) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest, opts ...ggrpc.CallOption) (*pb.Result, error) {
defer c.track(ctx)()
return c.Backend.SoundGeneration(ctx, in, opts...)
res, err := c.Backend.SoundGeneration(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest, opts ...ggrpc.CallOption) (*pb.TranscriptResult, error) {
defer c.track(ctx)()
return c.Backend.AudioTranscription(ctx, in, opts...)
res, err := c.Backend.AudioTranscription(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) AudioTranscriptionStream(ctx context.Context, in *pb.TranscriptRequest, f func(chunk *pb.TranscriptStreamResponse), opts ...ggrpc.CallOption) error {
defer c.track(ctx)()
return c.Backend.AudioTranscriptionStream(ctx, in, f, opts...)
return c.reconcile(c.Backend.AudioTranscriptionStream(ctx, in, f, opts...))
}

func (c *InFlightTrackingClient) Detect(ctx context.Context, in *pb.DetectOptions, opts ...ggrpc.CallOption) (*pb.DetectResponse, error) {
defer c.track(ctx)()
return c.Backend.Detect(ctx, in, opts...)
res, err := c.Backend.Detect(ctx, in, opts...)
return res, c.reconcile(err)
}

func (c *InFlightTrackingClient) Rerank(ctx context.Context, in *pb.RerankRequest, opts ...ggrpc.CallOption) (*pb.RerankResult, error) {
defer c.track(ctx)()
return c.Backend.Rerank(ctx, in, opts...)
res, err := c.Backend.Rerank(ctx, in, opts...)
return res, c.reconcile(err)
}
37 changes: 37 additions & 0 deletions core/services/nodes/inflight_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,17 @@ type fakeInFlightTracker struct {
mu sync.Mutex
increments int
decrements int
removed int
incrementErr error
}

func (f *fakeInFlightTracker) RemoveNodeModel(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
f.removed++
return nil
}

func (f *fakeInFlightTracker) IncrementInFlight(_ context.Context, _, _ string, _ int) error {
f.mu.Lock()
defer f.mu.Unlock()
Expand Down Expand Up @@ -295,4 +303,33 @@ var _ = Describe("InFlightTrackingClient", func() {
Expect(tracker.decrements).To(Equal(1))
})
})

Describe("stale model reload (self-heal)", func() {
It("removes the replica when the backend reports the model is not loaded", func() {
backend.predictErr = fmt.Errorf("parakeet-cpp: model not loaded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})

It("keeps the replica on an unrelated error", func() {
backend.predictErr = fmt.Errorf("context deadline exceeded")
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})

It("does not remove on success", func() {
_, err := client.Predict(context.Background(), &pb.PredictOptions{})
Expect(err).ToNot(HaveOccurred())
Expect(tracker.removed).To(Equal(0))
})

It("self-heals on a streamed call too", func() {
backend.streamErr = fmt.Errorf("whisper: model not loaded")
err := client.PredictStream(context.Background(), &pb.PredictOptions{}, func(*pb.Reply) {})
Expect(err).To(HaveOccurred())
Expect(tracker.removed).To(Equal(1))
})
})
})
3 changes: 3 additions & 0 deletions core/services/nodes/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ type ModelLookup interface {
type InFlightTracker interface {
IncrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
DecrementInFlight(ctx context.Context, nodeID, modelName string, replicaIndex int) error
// RemoveNodeModel drops a stale replica row so the next request reloads the
// model instead of routing back to a node where it is no longer loaded.
RemoveNodeModel(ctx context.Context, nodeID, modelName string, replicaIndex int) error
}

// NodeManager is used by HTTP endpoints for node registration and lifecycle.
Expand Down
35 changes: 35 additions & 0 deletions pkg/grpc/grpcerrors/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Package grpcerrors defines well-known error signals shared between backends
// (which produce them) and the router (which consumes them). Go error types do
// not survive the gRPC boundary, so these conditions are carried as gRPC status
// codes and detected via the code rather than by matching the error message.
package grpcerrors

import (
"strings"

"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// ModelNotLoaded returns the canonical error a backend returns when it has no
// model loaded for the request. It carries codes.FailedPrecondition so callers
// can detect it across the gRPC boundary without matching the message string.
func ModelNotLoaded(backend string) error {
return status.Errorf(codes.FailedPrecondition, "%s: model not loaded", backend)
}

// IsModelNotLoaded reports whether err signals that the backend has no model
// loaded. It prefers the typed gRPC status code (FailedPrecondition) and falls
// back to the message for backends that have not yet adopted ModelNotLoaded.
//
// Acting on a false positive is harmless: the only consequence upstream is that
// the model is reloaded, which is idempotent.
func IsModelNotLoaded(err error) bool {
if err == nil {
return false
}
if status.Code(err) == codes.FailedPrecondition {
return true
}
return strings.Contains(strings.ToLower(err.Error()), "model not loaded")
}
37 changes: 37 additions & 0 deletions pkg/grpc/grpcerrors/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package grpcerrors_test

import (
"errors"
"testing"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"

"github.com/mudler/LocalAI/pkg/grpc/grpcerrors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

func TestGRPCErrors(t *testing.T) {
RegisterFailHandler(Fail)
RunSpecs(t, "grpcerrors test suite")
}

var _ = Describe("grpcerrors", func() {
DescribeTable("IsModelNotLoaded",
func(err error, want bool) {
Expect(grpcerrors.IsModelNotLoaded(err)).To(Equal(want))
},
Entry("nil", nil, false),
Entry("typed via constructor", grpcerrors.ModelNotLoaded("parakeet-cpp"), true),
Entry("typed code only", status.Error(codes.FailedPrecondition, "anything"), true),
Entry("legacy message (Unknown code)", errors.New("parakeet-cpp: model not loaded"), true),
Entry("legacy message mixed case", errors.New("Backend: Model Not Loaded"), true),
Entry("unrelated error", errors.New("context deadline exceeded"), false),
Entry("unrelated grpc code", status.Error(codes.Unavailable, "connection refused"), false),
)

It("ModelNotLoaded carries FailedPrecondition", func() {
Expect(status.Code(grpcerrors.ModelNotLoaded("whisper"))).To(Equal(codes.FailedPrecondition))
})
})
Loading