Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ public abstract class AbstractPythonCodegen extends DefaultCodegen implements Co
private final Logger LOGGER = LoggerFactory.getLogger(AbstractPythonCodegen.class);

public static final String MAP_NUMBER_TO = "mapNumberTo";
public static final String PYDANTIC = "pydantic";
public static final Set<String> SUPPORTED_NUMBER_MAPPINGS =
Set.of("Union[StrictFloat, StrictInt]", "StrictFloat", "float");

protected String packageName = "openapi_client";
@Setter protected String packageVersion = "1.0.0";
Expand Down Expand Up @@ -985,16 +988,16 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
codegenProperties = model.getComposedSchemas().getOneOf();
moduleImports.add("typing", "Any");
moduleImports.add("typing", "List");
moduleImports.add("pydantic", "Field");
moduleImports.add("pydantic", "StrictStr");
moduleImports.add("pydantic", "ValidationError");
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "Field");
moduleImports.add(PYDANTIC, "StrictStr");
moduleImports.add(PYDANTIC, "ValidationError");
moduleImports.add(PYDANTIC, "field_validator");
} else if (!model.anyOf.isEmpty()) { // anyOF
codegenProperties = model.getComposedSchemas().getAnyOf();
moduleImports.add("pydantic", "Field");
moduleImports.add("pydantic", "StrictStr");
moduleImports.add("pydantic", "ValidationError");
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "Field");
moduleImports.add(PYDANTIC, "StrictStr");
moduleImports.add(PYDANTIC, "ValidationError");
moduleImports.add(PYDANTIC, "field_validator");
} else { // typical model
codegenProperties = model.vars;

Expand Down Expand Up @@ -1029,7 +1032,7 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {

// if pydantic model
if (!model.isEnum) {
moduleImports.add("pydantic", "ConfigDict");
moduleImports.add(PYDANTIC, "ConfigDict");
}

//loop through properties/schemas to set up typing, pydantic
Expand All @@ -1054,7 +1057,7 @@ private ModelsMap postProcessModelsMap(ModelsMap objs) {
if (!StringUtils.isEmpty(model.parent)) {
modelImports.add(model.parent);
} else if (!model.isEnum) {
moduleImports.add("pydantic", "BaseModel");
moduleImports.add(PYDANTIC, "BaseModel");
}

// set enum type in extensions and update `name` in enumVars
Expand Down Expand Up @@ -1155,12 +1158,10 @@ private PythonType getPydanticType(CodegenProperty cp,
}

public void setMapNumberTo(String mapNumberTo) {
if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)
|| "StrictFloat".equals(mapNumberTo)
|| "float".equals(mapNumberTo)) {
if (SUPPORTED_NUMBER_MAPPINGS.contains(mapNumberTo)) {
this.mapNumberTo = mapNumberTo;
} else {
throw new IllegalArgumentException("mapNumberTo value must be Union[StrictFloat, StrictInt], StrictFloat or float");
throw new IllegalArgumentException(String.format(Locale.ROOT, "mapNumberTo supports %s", SUPPORTED_NUMBER_MAPPINGS));
}
}

Expand Down Expand Up @@ -1721,7 +1722,7 @@ private String asTypeConstraint(PythonImports imports, boolean withAnnotations)
}

if (fieldParams.size() > 0) {
imports.add("pydantic", "Field");
imports.add(PYDANTIC, "Field");
imports.add("typing_extensions", "Annotated");
currentType = "Annotated[" + currentType + ", Field(" + StringUtils.join(fieldParams, ", ") + ")]";
}
Expand Down Expand Up @@ -1761,7 +1762,7 @@ public String asTypeValue(PythonImports imports) {
ants.add(ans);
}

imports.add("pydantic", "Field");
imports.add(PYDANTIC, "Field");
typeValue = "Field(" + StringUtils.join(ants, ", ") + ")";
return typeValue;
}
Expand Down Expand Up @@ -1827,6 +1828,15 @@ public boolean isEmpty() {
}

class PydanticType {

private static final String LESS_THAN = "lt";
private static final String GREATER_THAN = "gt";
private static final String GREATER_OR_EQUAL_TO = "ge";
private static final String LESS_OR_EQUAL_TO = "le";
private static final String TYPING = "typing";

private static final String DECIMAL = "Decimal";

private Set<String> modelImports;
private Set<String> exampleImports;
private Set<String> postponedModelImports;
Expand Down Expand Up @@ -1868,10 +1878,10 @@ private PythonType arrayType(IJsonSchemaValidationProperties cp) {
//pt.setType("Set");
//moduleImports.add("typing", "Set");
pt.setType("List");
moduleImports.add("typing", "List");
moduleImports.add(TYPING, "List");
} else {
pt.setType("List");
moduleImports.add("typing", "List");
moduleImports.add(TYPING, "List");
}
pt.addTypeParam(collectionItemType(cp.getItems()));
return pt;
Expand All @@ -1880,7 +1890,7 @@ private PythonType arrayType(IJsonSchemaValidationProperties cp) {
private PythonType collectionItemType(CodegenProperty itemCp) {
PythonType itemPt = getType(itemCp);
if (itemCp != null && !itemPt.type.equals("Any") && itemCp.isNullable) {
moduleImports.add("typing", "Optional");
moduleImports.add(TYPING, "Optional");
PythonType opt = new PythonType("Optional");
opt.addTypeParam(itemPt);
itemPt = opt;
Expand All @@ -1903,24 +1913,24 @@ private PythonType stringType(IJsonSchemaValidationProperties cp) {
}

if (cp.getPattern() != null) {
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "field_validator");
// use validator instead as regex doesn't support flags, e.g. IGNORECASE
//fieldCustomization.add(Locale.ROOT, String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern()));
}
return pt;
} else {
if ("password".equals(cp.getFormat())) { // TODO avoid using format, use `is` boolean flag instead
moduleImports.add("pydantic", "SecretStr");
moduleImports.add(PYDANTIC, "SecretStr");
return new PythonType("SecretStr");
} else {
moduleImports.add("pydantic", "StrictStr");
moduleImports.add(PYDANTIC, "StrictStr");
return new PythonType("StrictStr");
}
}
}

private PythonType mapType(IJsonSchemaValidationProperties cp) {
moduleImports.add("typing", "Dict");
moduleImports.add(TYPING, "Dict");
PythonType pt = new PythonType("Dict");
pt.addTypeParam(new PythonType("str"));
pt.addTypeParam(collectionItemType(cp.getItems()));
Expand All @@ -1935,20 +1945,20 @@ private PythonType numberType(IJsonSchemaValidationProperties cp) {
// e.g. confloat(ge=10, le=100, strict=True)
if (cp.getMaximum() != null) {
if (cp.getExclusiveMaximum()) {
floatt.constrain("lt", cp.getMaximum(), false);
intt.constrain("lt", (int) Math.ceil(Double.valueOf(cp.getMaximum()))); // e.g. < 7.59 => < 8
floatt.constrain(LESS_THAN, cp.getMaximum(), false);
intt.constrain(LESS_THAN, (int) Math.ceil(Double.valueOf(cp.getMaximum()))); // e.g. < 7.59 => < 8
} else {
floatt.constrain("le", cp.getMaximum(), false);
intt.constrain("le", (int) Math.floor(Double.valueOf(cp.getMaximum()))); // e.g. <= 7.59 => <= 7
floatt.constrain(LESS_OR_EQUAL_TO, cp.getMaximum(), false);
intt.constrain(LESS_OR_EQUAL_TO, (int) Math.floor(Double.valueOf(cp.getMaximum()))); // e.g. <= 7.59 => <= 7
}
}
if (cp.getMinimum() != null) {
if (cp.getExclusiveMinimum()) {
floatt.constrain("gt", cp.getMinimum(), false);
intt.constrain("gt", (int) Math.floor(Double.valueOf(cp.getMinimum()))); // e.g. > 7.59 => > 7
floatt.constrain(GREATER_THAN, cp.getMinimum(), false);
intt.constrain(GREATER_THAN, (int) Math.floor(Double.valueOf(cp.getMinimum()))); // e.g. > 7.59 => > 7
} else {
floatt.constrain("ge", cp.getMinimum(), false);
intt.constrain("ge", (int) Math.ceil(Double.valueOf(cp.getMinimum()))); // e.g. >= 7.59 => >= 8
floatt.constrain(GREATER_OR_EQUAL_TO, cp.getMinimum(), false);
intt.constrain(GREATER_OR_EQUAL_TO, (int) Math.ceil(Double.valueOf(cp.getMinimum()))); // e.g. >= 7.59 => >= 8
}
}
if (cp.getMultipleOf() != null) {
Expand All @@ -1959,7 +1969,7 @@ private PythonType numberType(IJsonSchemaValidationProperties cp) {
floatt.constrain("strict", true);
intt.constrain("strict", true);

moduleImports.add("typing", "Union");
moduleImports.add(TYPING, "Union");
PythonType pt = new PythonType("Union");
pt.addTypeParam(floatt);
pt.addTypeParam(intt);
Expand All @@ -1972,15 +1982,15 @@ private PythonType numberType(IJsonSchemaValidationProperties cp) {
}
} else {
if ("Union[StrictFloat, StrictInt]".equals(mapNumberTo)) {
moduleImports.add("typing", "Union");
moduleImports.add("pydantic", "StrictFloat");
moduleImports.add("pydantic", "StrictInt");
moduleImports.add(TYPING, "Union");
moduleImports.add(PYDANTIC, "StrictFloat");
moduleImports.add(PYDANTIC, "StrictInt");
PythonType pt = new PythonType("Union");
pt.addTypeParam(new PythonType("StrictFloat"));
pt.addTypeParam(new PythonType("StrictInt"));
return pt;
} else if ("StrictFloat".equals(mapNumberTo)) {
moduleImports.add("pydantic", "StrictFloat");
moduleImports.add(PYDANTIC, "StrictFloat");
return new PythonType("StrictFloat");
} else {
return new PythonType("float");
Expand All @@ -1993,26 +2003,10 @@ private PythonType intType(IJsonSchemaValidationProperties cp) {
PythonType pt = new PythonType("int");
// e.g. conint(ge=10, le=100, strict=True)
pt.constrain("strict", true);
if (cp.getMaximum() != null) {
if (cp.getExclusiveMaximum()) {
pt.constrain("lt", cp.getMaximum(), false);
} else {
pt.constrain("le", cp.getMaximum(), false);
}
}
if (cp.getMinimum() != null) {
if (cp.getExclusiveMinimum()) {
pt.constrain("gt", cp.getMinimum(), false);
} else {
pt.constrain("ge", cp.getMinimum(), false);
}
}
if (cp.getMultipleOf() != null) {
pt.constrain("multiple_of", cp.getMultipleOf());
}
applyConstraints(pt, cp);
return pt;
} else {
moduleImports.add("pydantic", "StrictInt");
moduleImports.add(PYDANTIC, "StrictInt");
return new PythonType("StrictInt");
}
}
Expand All @@ -2034,19 +2028,19 @@ private PythonType binaryType(IJsonSchemaValidationProperties cp) {
strt.constrain("min_length", cp.getMinLength());
}
if (cp.getPattern() != null) {
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "field_validator");
// use validator instead as regex doesn't support flags, e.g. IGNORECASE
//fieldCustomization.add(Locale.ROOT, String.format(Locale.ROOT, "regex=r'%s'", cp.getPattern()));
}

moduleImports.add("typing", "Union");
moduleImports.add(TYPING, "Union");

PythonType pt = new PythonType("Union");
pt.addTypeParam(bytest);
pt.addTypeParam(strt);

if (cp.getIsBinary()) {
moduleImports.add("typing", "Tuple");
moduleImports.add(TYPING, "Tuple");

PythonType tt = new PythonType("Tuple");
// this string is a filename, not a validated value
Expand All @@ -2059,16 +2053,16 @@ private PythonType binaryType(IJsonSchemaValidationProperties cp) {
return pt;
} else {
// same as above which has validation
moduleImports.add("pydantic", "StrictBytes");
moduleImports.add("pydantic", "StrictStr");
moduleImports.add("typing", "Union");
moduleImports.add(PYDANTIC, "StrictBytes");
moduleImports.add(PYDANTIC, "StrictStr");
moduleImports.add(TYPING, "Union");

PythonType pt = new PythonType("Union");
pt.addTypeParam(new PythonType("StrictBytes"));
pt.addTypeParam(new PythonType("StrictStr"));

if (cp.getIsBinary()) {
moduleImports.add("typing", "Tuple");
moduleImports.add(TYPING, "Tuple");

PythonType tt = new PythonType("Tuple");
tt.addTypeParam(new PythonType("StrictStr"));
Expand All @@ -2082,41 +2076,25 @@ private PythonType binaryType(IJsonSchemaValidationProperties cp) {
}

private PythonType boolType(IJsonSchemaValidationProperties cp) {
moduleImports.add("pydantic", "StrictBool");
moduleImports.add(PYDANTIC, "StrictBool");
return new PythonType("StrictBool");
}

private PythonType decimalType(IJsonSchemaValidationProperties cp) {
PythonType pt = new PythonType("Decimal");
moduleImports.add("decimal", "Decimal");
PythonType pt = new PythonType(DECIMAL);
moduleImports.add("decimal", DECIMAL);

if (cp.getHasValidation()) {
// e.g. condecimal(ge=10, le=100, strict=True)
pt.constrain("strict", true);
if (cp.getMaximum() != null) {
if (cp.getExclusiveMaximum()) {
pt.constrain("gt", cp.getMaximum(), false);
} else {
pt.constrain("ge", cp.getMaximum(), false);
}
}
if (cp.getMinimum() != null) {
if (cp.getExclusiveMinimum()) {
pt.constrain("lt", cp.getMinimum(), false);
} else {
pt.constrain("le", cp.getMinimum(), false);
}
}
if (cp.getMultipleOf() != null) {
pt.constrain("multiple_of", cp.getMultipleOf());
}
applyConstraints(pt, cp);
}

return pt;
}

private PythonType anyType(IJsonSchemaValidationProperties cp) {
moduleImports.add("typing", "Any");
moduleImports.add(TYPING, "Any");
return new PythonType("Any");
}

Expand Down Expand Up @@ -2148,16 +2126,16 @@ private PythonType fromCommon(IJsonSchemaValidationProperties cp) {
if (cp == null) {
// if codegen property (e.g. map/dict of undefined type) is null, default to string
LOGGER.warn("Codegen property is null (e.g. map/dict of undefined type). Default to typing.Any.");
moduleImports.add("typing", "Any");
moduleImports.add(TYPING, "Any");
return new PythonType("Any");
}

if (cp.getIsEnum()) {
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "field_validator");
}

if (cp.getPattern() != null) {
moduleImports.add("pydantic", "field_validator");
moduleImports.add(PYDANTIC, "field_validator");
}

if (cp.getIsArray()) {
Expand Down Expand Up @@ -2199,7 +2177,7 @@ private PythonType getType(CodegenProperty cp) {
also need to put cp.isEnum check after isArray, isMap check
if (cp.isEnum) {
// use Literal for inline enum
moduleImports.add("typing", "Literal");
moduleImports.add(TYPING, "Literal");
List<String> values = new ArrayList<>();
List<Map<String, Object>> enumVars = (List<Map<String, Object>>) cp.allowableValues.get("enumVars");
if (enumVars != null) {
Expand Down Expand Up @@ -2248,7 +2226,7 @@ private PythonType getType(CodegenProperty cp) {

private String finalizeType(CodegenProperty cp, PythonType pt) {
if (!cp.required || cp.isNullable) {
moduleImports.add("typing", "Optional");
moduleImports.add(TYPING, "Optional");
PythonType opt = new PythonType("Optional");
opt.addTypeParam(pt);
pt = opt;
Expand Down Expand Up @@ -2327,6 +2305,26 @@ private PythonType getType(CodegenParameter cp) {
return result;
}

private void applyConstraints(PythonType pythonType, IJsonSchemaValidationProperties cp) {
if (cp.getMaximum() != null) {
if (cp.getExclusiveMaximum()) {
pythonType.constrain(LESS_THAN, cp.getMaximum(), false);
} else {
pythonType.constrain(LESS_OR_EQUAL_TO, cp.getMaximum(), false);
}
}
if (cp.getMinimum() != null) {
if (cp.getExclusiveMinimum()) {
pythonType.constrain(GREATER_THAN, cp.getMinimum(), false);
} else {
pythonType.constrain(GREATER_OR_EQUAL_TO, cp.getMinimum(), false);
}
}
if (cp.getMultipleOf() != null) {
pythonType.constrain("multiple_of", cp.getMultipleOf());
}
}

private String finalizeType(CodegenParameter cp, PythonType pt) {
if (!cp.required || cp.isNullable) {
moduleImports.add("typing", "Optional");
Expand Down
Loading
Loading