package com.android.build.gradle.internal.tasks.mlkit.codegen;

import com.android.build.gradle.internal.tasks.mlkit.codegen.codeinjector.InjectorUtils;
import com.android.tools.mlkit.MetadataExtractor;
import com.android.tools.mlkit.MlkitNames;
import com.android.tools.mlkit.ModelInfo;
import com.android.tools.mlkit.ModelParsingException;
import com.android.tools.mlkit.TensorInfo;
import com.google.common.base.CaseFormat;
import com.squareup.javapoet.ClassName;
import com.squareup.javapoet.FieldSpec;
import com.squareup.javapoet.JavaFile;
import com.squareup.javapoet.MethodSpec;
import com.squareup.javapoet.TypeSpec;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.Iterator;
import javax.lang.model.element.Modifier;
import org.apache.commons.io.FilenameUtils;
import org.gradle.api.file.DirectoryProperty;
import org.gradle.api.logging.Logger;
import org.gradle.api.logging.Logging;

/* loaded from: input_file:com/android/build/gradle/internal/tasks/mlkit/codegen/TfliteModelGenerator.class */
/**
 * Generates a Java wrapper class for a TensorFlow Lite model file.
 *
 * <p>The generated class is named after the model file (its base name converted from
 * lower_underscore to UpperCamelCase). It contains:
 *
 * <ul>
 *   <li>per-tensor fields;
 *   <li>a constructor that loads the model;
 *   <li>{@code createInputs()}, {@code run(...)} and {@code getAssociatedFile(...)} methods;
 *   <li>nested inputs/outputs classes named by {@link MlkitNames#INPUTS} and
 *       {@link MlkitNames#OUTPUTS}.
 * </ul>
 */
public class TfliteModelGenerator implements ModelGenerator {
    private static final String FIELD_MODEL = "model";
    private static final String FIELD_METADATA_EXTRACTOR = "extractor";
    private final Logger logger = Logging.getLogger(getClass());
    private final String localModelPath;
    private final MetadataExtractor extractor;
    private final ModelInfo modelInfo;
    private final String className;
    private final String packageName;

    /**
     * Parses the model and derives the name of the class to generate.
     *
     * @param file the model file; its name without extension is converted from lower_underscore
     *     to UpperCamelCase to form the generated class name
     * @param packageName package the generated class is written into
     * @param localModelPath path embedded in the generated code; it is used to open the model
     *     from the app's assets at runtime
     * @throws ModelParsingException if the model cannot be parsed into {@link ModelInfo}
     */
    public TfliteModelGenerator(File file, String packageName, String localModelPath)
            throws ModelParsingException {
        this.extractor = ModelUtils.createMetadataExtractor(file);
        this.localModelPath = localModelPath;
        this.modelInfo = ModelInfo.buildFrom(this.extractor);
        this.packageName = packageName;
        this.className =
                CaseFormat.LOWER_UNDERSCORE.to(
                        CaseFormat.UPPER_CAMEL, FilenameUtils.removeExtension(file.getName()));
    }

    /**
     * Builds the wrapper class and writes it as a Java source file under the given directory.
     *
     * <p>Write failures are logged as warnings rather than thrown.
     *
     * @param directoryProperty root directory for the generated sources
     */
    @Override // com.android.build.gradle.internal.tasks.mlkit.codegen.ModelGenerator
    public void generateBuildClass(DirectoryProperty directoryProperty) {
        TypeSpec.Builder classBuilder =
                TypeSpec.classBuilder(className).addModifiers(Modifier.PUBLIC, Modifier.FINAL);
        String description = modelInfo.getModelDescription();
        if (description != null) {
            // Pass the description as an argument, never as the format string: JavaPoet treats
            // '$' in a format string as a placeholder and would throw on such metadata.
            classBuilder.addJavadoc("$L", description);
        }
        buildFields(classBuilder);
        buildConstructor(classBuilder);
        buildCreateInputsMethod(classBuilder);
        buildGetAssociatedFileMethod(classBuilder);
        buildRunMethod(classBuilder);
        buildInnerClass(classBuilder);

        File outputDir = directoryProperty.getAsFile().get();
        try {
            JavaFile.builder(packageName, classBuilder.build()).build().writeTo(outputDir);
        } catch (IOException e) {
            // Best-effort: keep the build going, but surface the failure and its cause instead of
            // hiding it at debug level (callers would otherwise see only a missing class later).
            logger.warn("Failed to write mlkit generated java file for " + className, e);
        }
    }

    /**
     * Runs the field injector for every input tensor, then every output tensor. It then adds the
     * private final model field.
     */
    private void buildFields(TypeSpec.Builder classBuilder) {
        for (TensorInfo input : modelInfo.getInputs()) {
            InjectorUtils.getFieldInjector().inject(classBuilder, input);
        }
        for (TensorInfo output : modelInfo.getOutputs()) {
            InjectorUtils.getFieldInjector().inject(classBuilder, output);
        }
        classBuilder.addField(
                FieldSpec.builder(ClassNames.MODEL, FIELD_MODEL)
                        .addModifiers(Modifier.PRIVATE, Modifier.FINAL)
                        .build());
    }

    /**
     * Emits {@code getAssociatedFile(context, fileName)}. The generated method reads the model
     * asset into memory, opens it as a zip archive, and returns the raw stream of the named entry.
     *
     * <p>The generated code closes the asset stream once it has been fully read. The returned
     * stream is backed by the in-memory copy, so closing the asset stream does not affect it.
     */
    private void buildGetAssociatedFileMethod(TypeSpec.Builder classBuilder) {
        MethodSpec.Builder method =
                MethodSpec.methodBuilder("getAssociatedFile")
                        .addParameter(ClassNames.CONTEXT, "context")
                        .addParameter(String.class, "fileName")
                        .addException(IOException.class)
                        .returns(InputStream.class)
                        .addStatement(
                                "$T inputStream = context.getAssets().open($S)",
                                InputStream.class,
                                localModelPath)
                        // Plain try/finally (not try-with-resources) so the generated code has no
                        // extra Android API-level requirements.
                        .beginControlFlow("try")
                        .addStatement(
                                "$T zipFile = new $T(new $T($T.toByteArray(inputStream)))",
                                ClassNames.ZIP_FILE,
                                ClassNames.ZIP_FILE,
                                ClassNames.SEEKABLE_IN_MEMORY_BYTE_CHANNEL,
                                ClassNames.IO_UTILS)
                        .addStatement(
                                "return zipFile.getRawInputStream(zipFile.getEntry(fileName))")
                        .nextControlFlow("finally")
                        .addStatement("inputStream.close()")
                        .endControlFlow();
        classBuilder.addMethod(method.build());
    }

    /** Adds the nested outputs class, then the nested inputs class, via their injectors. */
    private void buildInnerClass(TypeSpec.Builder classBuilder) {
        InjectorUtils.getOutputsClassInjector().inject(classBuilder, modelInfo.getOutputs());
        InjectorUtils.getInputsClassInjector().inject(classBuilder, modelInfo.getInputs());
    }

    /**
     * Emits the public {@code (Context)} constructor. It builds the model from
     * {@code localModelPath}, then lets the per-tensor injectors append their initialization code.
     */
    private void buildConstructor(TypeSpec.Builder classBuilder) {
        MethodSpec.Builder constructor =
                MethodSpec.constructorBuilder()
                        .addModifiers(Modifier.PUBLIC)
                        .addParameter(ClassNames.CONTEXT, "context")
                        .addException(ClassNames.IO_EXCEPTION)
                        .addStatement(
                                "$L = new $T.Builder(context, $S).build()",
                                FIELD_MODEL,
                                ClassNames.MODEL,
                                localModelPath);
        for (TensorInfo input : modelInfo.getInputs()) {
            InjectorUtils.getInputProcessorInjector(input).inject(constructor, input);
        }
        for (TensorInfo output : modelInfo.getOutputs()) {
            InjectorUtils.getOutputProcessorInjector(output).inject(constructor, output);
            InjectorUtils.getAssociatedFileInjector().inject(constructor, output);
        }
        classBuilder.addMethod(constructor.build());
    }

    /**
     * Emits {@code run(inputs)}. The generated method creates a fresh outputs object, runs the
     * model on the input/output buffers, and returns the outputs.
     */
    private void buildRunMethod(TypeSpec.Builder classBuilder) {
        ClassName generatedClass = generatedClassName();
        ClassName outputsClass = generatedClass.nestedClass(MlkitNames.OUTPUTS);
        ClassName inputsClass = generatedClass.nestedClass(MlkitNames.INPUTS);
        MethodSpec.Builder method =
                MethodSpec.methodBuilder("run")
                        .addModifiers(Modifier.PUBLIC)
                        .addParameter(inputsClass, "inputs")
                        .returns(outputsClass)
                        .addStatement("$T $L = new $T()", outputsClass, "outputs", outputsClass)
                        .addStatement(
                                "$L.run($L.getBuffer(), $L.getBuffer())",
                                FIELD_MODEL,
                                "inputs",
                                "outputs")
                        .addStatement("return $L", "outputs");
        classBuilder.addMethod(method.build());
    }

    /** Emits {@code createInputs()}, which returns a new instance of the nested inputs class. */
    private void buildCreateInputsMethod(TypeSpec.Builder classBuilder) {
        ClassName inputsClass = generatedClassName().nestedClass(MlkitNames.INPUTS);
        classBuilder.addMethod(
                MethodSpec.methodBuilder("createInputs")
                        .addModifiers(Modifier.PUBLIC)
                        .returns(inputsClass)
                        .addStatement("return new $L()", MlkitNames.INPUTS)
                        .build());
    }

    /** Returns the JavaPoet name of the top-level class being generated. */
    private ClassName generatedClassName() {
        return ClassName.get(packageName, className);
    }
}
