diff --git a/cel/env.go b/cel/env.go index ea5597e2..3bfe4289 100644 --- a/cel/env.go +++ b/cel/env.go @@ -217,7 +217,7 @@ func (e *Env) Check(ast *Ast) (*Ast, *Issues) { chk, err := e.initChecker() if err != nil { errs := common.NewErrors(ast.Source()) - errs.ReportError(common.NoLocation, err.Error()) + errs.ReportErrorString(common.NoLocation, err.Error()) return nil, NewIssuesWithSourceInfo(errs, ast.NativeRep().SourceInfo()) } diff --git a/cel/library.go b/cel/library.go index 3dc8594b..e36590e2 100644 --- a/cel/library.go +++ b/cel/library.go @@ -731,7 +731,7 @@ var ( func timestampGetFullYear(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Year()) } @@ -739,7 +739,7 @@ func timestampGetFullYear(ts, tz ref.Val) ref.Val { func timestampGetMonth(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } // CEL spec indicates that the month should be 0-based, but the Time value // for Month() is 1-based. @@ -749,7 +749,7 @@ func timestampGetMonth(ts, tz ref.Val) ref.Val { func timestampGetDayOfYear(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.YearDay() - 1) } @@ -757,7 +757,7 @@ func timestampGetDayOfYear(ts, tz ref.Val) ref.Val { func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Day() - 1) } @@ -765,7 +765,7 @@ func timestampGetDayOfMonthZeroBased(ts, tz ref.Val) ref.Val { func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Day()) } @@ -773,7 +773,7 @@ func timestampGetDayOfMonthOneBased(ts, tz ref.Val) ref.Val { func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Weekday()) } @@ -781,7 +781,7 @@ func timestampGetDayOfWeek(ts, tz ref.Val) ref.Val { func timestampGetHours(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Hour()) } @@ -789,7 +789,7 @@ func timestampGetHours(ts, tz ref.Val) ref.Val { func timestampGetMinutes(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Minute()) } @@ -797,7 +797,7 @@ func timestampGetMinutes(ts, tz ref.Val) ref.Val { func timestampGetSeconds(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Second()) } @@ -805,7 +805,7 @@ func timestampGetSeconds(ts, tz ref.Val) ref.Val { func timestampGetMilliseconds(ts, tz ref.Val) ref.Val { t, err := inTimeZone(ts, tz) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(t.Nanosecond() / 1000000) } diff --git a/common/errors.go b/common/errors.go index 89570683..c8865df8 100644 --- a/common/errors.go +++ b/common/errors.go @@ -46,6 +46,11 @@ func (e *Errors) ReportError(l Location, format string, args ...any) { e.ReportErrorAtID(0, l, format, args...) } +// ReportErrorString records an error at a source location. +func (e *Errors) ReportErrorString(l Location, message string) { + e.ReportErrorAtID(0, l, "%s", message) +} + // ReportErrorAtID records an error at a source location and expression id. func (e *Errors) ReportErrorAtID(id int64, l Location, format string, args ...any) { e.numErrors++ diff --git a/common/types/err.go b/common/types/err.go index 9c9d9e21..ee1a76e7 100644 --- a/common/types/err.go +++ b/common/types/err.go @@ -62,6 +62,12 @@ func NewErr(format string, args ...any) ref.Val { return &Err{error: fmt.Errorf(format, args...)} } +// NewErr creates a new Err with the provided message. +// TODO: Audit the use of this function and standardize the error messages and codes. +func NewErrFromString(message string) ref.Val { + return &Err{error: errors.New(message)} +} + // NewErrWithNodeID creates a new Err described by the format string and args. // TODO: Audit the use of this function and standardize the error messages and codes. func NewErrWithNodeID(id int64, format string, args ...any) ref.Val { diff --git a/common/types/list.go b/common/types/list.go index ca47d39f..7e68a5da 100644 --- a/common/types/list.go +++ b/common/types/list.go @@ -243,7 +243,7 @@ func (l *baseList) Equal(other ref.Val) ref.Val { func (l *baseList) Get(index ref.Val) ref.Val { ind, err := IndexOrError(index) if err != nil { - return ValOrErr(index, err.Error()) + return ValOrErr(index, "%v", err) } if ind < 0 || ind >= l.size { return NewErr("index '%d' out of range in list size '%d'", ind, l.Size()) @@ -427,7 +427,7 @@ func (l *concatList) Equal(other ref.Val) ref.Val { func (l *concatList) Get(index ref.Val) ref.Val { ind, err := IndexOrError(index) if err != nil { - return ValOrErr(index, err.Error()) + return ValOrErr(index, "%v", err) } i := Int(ind) if i < l.prevList.Size().(Int) { diff --git a/common/types/object.go b/common/types/object.go index 8ba0af9f..5377bff8 100644 --- a/common/types/object.go +++ b/common/types/object.go @@ -151,7 +151,7 @@ func (o *protoObj) Get(index ref.Val) ref.Val { } fv, err := fd.GetFrom(o.value) if err != nil { - return NewErr(err.Error()) + return NewErrFromString(err.Error()) } return o.NativeToValue(fv) } diff --git a/ext/formatting.go b/ext/formatting.go index dbff613b..932d562e 100644 --- a/ext/formatting.go +++ b/ext/formatting.go @@ -434,7 +434,7 @@ func (stringFormatValidator) Validate(env *cel.Env, _ cel.ValidatorConfig, a *as // use a placeholder locale, since locale doesn't affect syntax _, err := parseFormatString(formatStr, formatCheck, formatCheck, "en_US") if err != nil { - iss.ReportErrorAtID(getErrorExprID(e.ID(), err), err.Error()) + iss.ReportErrorAtID(getErrorExprID(e.ID(), err), "%v", err) continue } seenArgs := formatCheck.argsRequested diff --git a/ext/guards.go b/ext/guards.go index ccede289..1461c041 100644 --- a/ext/guards.go +++ b/ext/guards.go @@ -24,28 +24,28 @@ import ( func intOrError(i int64, err error) ref.Val { if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Int(i) } func bytesOrError(bytes []byte, err error) ref.Val { if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.Bytes(bytes) } func stringOrError(str string, err error) ref.Val { if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.String(str) } func listStringOrError(strs []string, err error) ref.Val { if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return types.DefaultTypeAdapter.NativeToValue(strs) } diff --git a/ext/native.go b/ext/native.go index 83f36589..1c33def4 100644 --- a/ext/native.go +++ b/ext/native.go @@ -343,7 +343,7 @@ func (tp *nativeTypeProvider) NewValue(typeName string, fields map[string]ref.Va } fieldVal, err := val.ConvertToNative(refFieldDef.Type) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } refField := refVal.FieldByIndex(refFieldDef.Index) refFieldVal := reflect.ValueOf(fieldVal) @@ -450,7 +450,7 @@ func convertToCelType(refType reflect.Type) (*cel.Type, bool) { func (tp *nativeTypeProvider) newNativeObject(val any, refValue reflect.Value) ref.Val { valType, err := newNativeType(tp.options.fieldNameHandler, refValue.Type()) if err != nil { - return types.NewErr(err.Error()) + return types.NewErrFromString(err.Error()) } return &nativeObj{ Adapter: tp, diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index b89b2214..86b6025f 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -1120,7 +1120,7 @@ func TestAttributeStateTracking(t *testing.T) { } parsed, errors := p.Parse(src) if len(errors.GetErrors()) != 0 { - t.Fatalf(errors.ToDisplayString()) + t.Fatal(errors.ToDisplayString()) } cont := containers.DefaultContainer reg := newTestRegistry(t) @@ -1135,7 +1135,7 @@ func TestAttributeStateTracking(t *testing.T) { } checked, errors := checker.Check(parsed, src, env) if len(errors.GetErrors()) != 0 { - t.Fatalf(errors.ToDisplayString()) + t.Fatal(errors.ToDisplayString()) } in, err := NewActivation(tc.in) if err != nil { diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 00bf04dc..8582b6bc 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1780,7 +1780,7 @@ func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) { types.NewObjectType("google.expr.proto2.test.TestAllTypes"))) checked, errors := checker.Check(parsed, src, env) if len(errors.GetErrors()) != 0 { - t.Errorf(errors.ToDisplayString()) + t.Error(errors.ToDisplayString()) } attrs := NewAttributeFactory(cont, reg, reg) @@ -1829,7 +1829,7 @@ func TestInterpreter_MissingIdentInSelect(t *testing.T) { env.AddIdents(decls.NewVariable("a.b", types.DynType)) checked, errors := checker.Check(parsed, src, env) if len(errors.GetErrors()) != 0 { - t.Fatalf(errors.ToDisplayString()) + t.Fatal(errors.ToDisplayString()) } attrs := NewPartialAttributeFactory(cont, reg, reg) @@ -1885,7 +1885,7 @@ func TestInterpreter_TypeConversionOpt(t *testing.T) { env := newTestEnv(t, cont, reg) checked, errors := checker.Check(parsed, src, env) if len(errors.GetErrors()) != 0 { - t.Fatalf(errors.ToDisplayString()) + t.Fatal(errors.ToDisplayString()) } attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) diff --git a/parser/errors.go b/parser/errors.go index 93ae7a3a..c3cec01a 100644 --- a/parser/errors.go +++ b/parser/errors.go @@ -15,8 +15,6 @@ package parser import ( - "fmt" - "github.com/google/cel-go/common" ) @@ -31,11 +29,11 @@ func (e *parseErrors) errorCount() int { } func (e *parseErrors) internalError(message string) { - e.errs.ReportErrorAtID(0, common.NoLocation, message) + e.errs.ReportErrorAtID(0, common.NoLocation, "%s", message) } func (e *parseErrors) syntaxError(l common.Location, message string) { - e.errs.ReportErrorAtID(0, l, fmt.Sprintf("Syntax error: %s", message)) + e.errs.ReportErrorAtID(0, l, "Syntax error: %s", message) } func (e *parseErrors) reportErrorAtID(id int64, l common.Location, message string, args ...any) { diff --git a/parser/parser.go b/parser/parser.go index a77213b1..cbe80051 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -970,7 +970,7 @@ func (p *parser) expandMacro(exprID int64, function string, target ast.Expr, arg loc = p.helper.getLocation(exprID) } p.helper.deleteID(exprID) - return p.reportError(loc, err.Message), true + return p.reportError(loc, "%s", err.Message), true } // A nil value from the macro indicates that the macro implementation decided that // an expansion should not be performed. diff --git a/parser/parser_test.go b/parser/parser_test.go index ed72b378..887df6c3 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -2039,7 +2039,7 @@ func TestParse(t *testing.T) { if tc.E == "" { t.Fatalf("Unexpected errors: %v", actualErr) } else if !test.Compare(actualErr, tc.E) { - t.Fatalf(test.DiffMessage("Error mismatch", actualErr, tc.E)) + t.Fatal(test.DiffMessage("Error mismatch", actualErr, tc.E)) } return } else if tc.E != "" { diff --git a/repl/typefmt.go b/repl/typefmt.go index a25ff679..cf1b6ed5 100644 --- a/repl/typefmt.go +++ b/repl/typefmt.go @@ -312,7 +312,7 @@ func ParseType(t string) (*exprpb.Type, error) { for i, e := range errs { msgs[i] = e.Error() } - err = fmt.Errorf("errors parsing type:\n" + strings.Join(msgs, "\n")) + err = fmt.Errorf("errors parsing type:\n%s", strings.Join(msgs, "\n")) } return result, err