Skip to content

Instantly share code, notes, and snippets.

@mgalgs
Created July 31, 2020 19:33
Show Gist options
  • Select an option

  • Save mgalgs/f92d78e6c7ee09b8298cd325bf4f3ed6 to your computer and use it in GitHub Desktop.

Select an option

Save mgalgs/f92d78e6c7ee09b8298cd325bf4f3ed6 to your computer and use it in GitHub Desktop.

Revisions

  1. mgalgs created this gist Jul 31, 2020.
    75 changes: 75 additions & 0 deletions bndbox.dart
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,75 @@
    // Based on https://github.com/shaqian/flutter_realtime_detection

    import 'package:flutter/material.dart' show Border, BoxDecoration, BuildContext, Color, Container, EdgeInsets, FontWeight, Positioned, Stack, StatelessWidget, Text, TextStyle, Widget;
    import 'dart:math' as math;

    class BndBox extends StatelessWidget {
    final List<dynamic> results;
    final int previewH;
    final int previewW;
    final double screenH;
    final double screenW;

    BndBox(this.results, this.previewH, this.previewW, this.screenH, this.screenW);

    @override
    Widget build(BuildContext context) {
    List<Widget> _renderBoxes() {
    return results.map((re) {
    var _x = re["rect"]["x"];
    var _w = re["rect"]["w"];
    var _y = re["rect"]["y"];
    var _h = re["rect"]["h"];
    var scaleW, scaleH, x, y, w, h;

    if (screenH / screenW > previewH / previewW) {
    scaleW = screenH / previewH * previewW;
    scaleH = screenH;
    var difW = (scaleW - screenW) / scaleW;
    x = (_x - difW / 2) * scaleW;
    w = _w * scaleW;
    if (_x < difW / 2) w -= (difW / 2 - _x) * scaleW;
    y = _y * scaleH;
    h = _h * scaleH;
    } else {
    scaleH = screenW / previewW * previewH;
    scaleW = screenW;
    var difH = (scaleH - screenH) / scaleH;
    x = _x * scaleW;
    w = _w * scaleW;
    y = (_y - difH / 2) * scaleH;
    h = _h * scaleH;
    if (_y < difH / 2) h -= (difH / 2 - _y) * scaleH;
    }

    return Positioned(
    left: math.max(0, x),
    top: math.max(0, y),
    width: w,
    height: h,
    child: Container(
    padding: EdgeInsets.only(top: 5.0, left: 5.0),
    decoration: BoxDecoration(
    border: Border.all(
    color: Color.fromRGBO(37, 213, 253, 1.0),
    width: 3.0,
    ),
    ),
    child: Text(
    "${re["detectedClass"]} ${(re["confidenceInClass"] * 100).toStringAsFixed(0)}%",
    style: TextStyle(
    color: Color.fromRGBO(37, 213, 253, 1.0),
    fontSize: 14.0,
    fontWeight: FontWeight.bold,
    ),
    ),
    ),
    );
    }).toList();
    }

    return Stack(
    children: _renderBoxes(),
    );
    }
    }
    101 changes: 101 additions & 0 deletions camera.dart
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,101 @@
    // Based on https://github.com/shaqian/flutter_realtime_detection

    import 'package:flutter/material.dart' show BuildContext, Container, MediaQuery, OverflowBox, State, StatefulWidget, Widget;
    import 'package:camera/camera.dart';
    import 'package:tflite/tflite.dart';
    import 'dart:math' as math;

    typedef void Callback(List<dynamic> list, int h, int w);

    class Camera extends StatefulWidget {
    final List<CameraDescription> cameras;
    final Callback setRecognitions;

    Camera(this.cameras, this.setRecognitions);

    @override
    _CameraState createState() => new _CameraState();
    }

    class _CameraState extends State<Camera> {
    CameraController controller;
    bool isDetecting = false;

    @override
    void initState() {
    super.initState();

    if (widget.cameras == null || widget.cameras.length < 1) {
    print('No camera is found');
    } else {
    controller = new CameraController(
    widget.cameras[0],
    ResolutionPreset.high,
    );
    controller.initialize().then((_) {
    if (!mounted) {
    return;
    }
    setState(() {});

    controller.startImageStream((CameraImage img) {
    if (!isDetecting) {
    isDetecting = true;

    int startTime = new DateTime.now().millisecondsSinceEpoch;

    Tflite.detectObjectOnFrame(
    bytesList: img.planes.map((plane) {
    return plane.bytes;
    }).toList(),
    model: "SSDMobileNet",
    imageHeight: img.height,
    imageWidth: img.width,
    imageMean: 127.5,
    imageStd: 127.5,
    numResultsPerClass: 1,
    threshold: 0.4,
    ).then((recognitions) {
    int endTime = new DateTime.now().millisecondsSinceEpoch;
    print("Detection took ${endTime - startTime}");

    widget.setRecognitions(recognitions, img.height, img.width);

    isDetecting = false;
    });
    }
    });
    });
    }
    }

    @override
    void dispose() {
    controller?.dispose();
    super.dispose();
    }

    @override
    Widget build(BuildContext context) {
    if (controller == null || !controller.value.isInitialized) {
    return Container();
    }

    var tmp = MediaQuery.of(context).size;
    var screenH = math.max(tmp.height, tmp.width);
    var screenW = math.min(tmp.height, tmp.width);
    tmp = controller.value.previewSize;
    var previewH = math.max(tmp.height, tmp.width);
    var previewW = math.min(tmp.height, tmp.width);
    var screenRatio = screenH / screenW;
    var previewRatio = previewH / previewW;

    return OverflowBox(
    maxHeight:
    screenRatio > previewRatio ? screenH : screenW / previewW * previewH,
    maxWidth:
    screenRatio > previewRatio ? screenH / previewH * previewW : screenW,
    child: CameraPreview(controller),
    );
    }
    }
    65 changes: 65 additions & 0 deletions home.dart
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,65 @@
    // Based on https://github.com/shaqian/flutter_realtime_detection

    import 'package:flutter/material.dart' show BuildContext, MediaQuery, Scaffold, Size, Stack, State, StatefulWidget, Widget;
    import 'package:camera/camera.dart';
    import 'package:tflite/tflite.dart';
    import 'dart:math' as math;

    import 'camera.dart';
    import 'bndbox.dart';

    class HomePage extends StatefulWidget {
    final List<CameraDescription> cameras;

    HomePage(this.cameras);

    @override
    _HomePageState createState() => new _HomePageState();
    }

    class _HomePageState extends State<HomePage> {
    List<dynamic> _recognitions;
    int _imageHeight = 0;
    int _imageWidth = 0;

    @override
    void initState() {
    super.initState();
    loadModel();
    }

    void loadModel() async {
    await Tflite.loadModel(
    model: "assets/routespotter_model.tflite",
    labels: "assets/labels.txt");
    }

    setRecognitions(recognitions, imageHeight, imageWidth) {
    setState(() {
    _recognitions = recognitions;
    _imageHeight = imageHeight;
    _imageWidth = imageWidth;
    });
    }

    @override
    Widget build(BuildContext context) {
    Size screen = MediaQuery.of(context).size;
    return Scaffold(
    body: Stack(
    children: [
    Camera(
    widget.cameras,
    setRecognitions,
    ),
    BndBox(
    _recognitions == null ? [] : _recognitions,
    math.max(_imageHeight, _imageWidth),
    math.min(_imageHeight, _imageWidth),
    screen.height,
    screen.width),
    ],
    ),
    );
    }
    }
    104 changes: 104 additions & 0 deletions main.dart
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,104 @@
    // Based on https://github.com/shaqian/flutter_realtime_detection

    import 'dart:async';
    import 'package:flutter/material.dart';
    import 'package:camera/camera.dart' show CameraDescription, CameraException, availableCameras;
    import 'home.dart';

    List<CameraDescription> cameras;
    import 'package:flutter/material.dart' show Border, BoxDecoration, BuildContext, Color, Container, EdgeInsets, FontWeight, Positioned, Stack, StatelessWidget, Text, TextStyle, Widget;
    import 'dart:math' as math;

    class BndBox extends StatelessWidget {
    final List<dynamic> results;
    final int previewH;
    final int previewW;
    final double screenH;
    final double screenW;

    BndBox(this.results, this.previewH, this.previewW, this.screenH, this.screenW);

    @override
    Widget build(BuildContext context) {
    List<Widget> _renderBoxes() {
    return results.map((re) {
    var _x = re["rect"]["x"];
    var _w = re["rect"]["w"];
    var _y = re["rect"]["y"];
    var _h = re["rect"]["h"];
    var scaleW, scaleH, x, y, w, h;

    if (screenH / screenW > previewH / previewW) {
    scaleW = screenH / previewH * previewW;
    scaleH = screenH;
    var difW = (scaleW - screenW) / scaleW;
    x = (_x - difW / 2) * scaleW;
    w = _w * scaleW;
    if (_x < difW / 2) w -= (difW / 2 - _x) * scaleW;
    y = _y * scaleH;
    h = _h * scaleH;
    } else {
    scaleH = screenW / previewW * previewH;
    scaleW = screenW;
    var difH = (scaleH - screenH) / scaleH;
    x = _x * scaleW;
    w = _w * scaleW;
    y = (_y - difH / 2) * scaleH;
    h = _h * scaleH;
    if (_y < difH / 2) h -= (difH / 2 - _y) * scaleH;
    }

    return Positioned(
    left: math.max(0, x),
    top: math.max(0, y),
    width: w,
    height: h,
    child: Container(
    padding: EdgeInsets.only(top: 5.0, left: 5.0),
    decoration: BoxDecoration(
    border: Border.all(
    color: Color.fromRGBO(37, 213, 253, 1.0),
    width: 3.0,
    ),
    ),
    child: Text(
    "${re["detectedClass"]} ${(re["confidenceInClass"] * 100).toStringAsFixed(0)}%",
    style: TextStyle(
    color: Color.fromRGBO(37, 213, 253, 1.0),
    fontSize: 14.0,
    fontWeight: FontWeight.bold,
    ),
    ),
    ),
    );
    }).toList();
    }

    return Stack(
    children: _renderBoxes(),
    );
    }
    }

    Future<Null> main() async {
    WidgetsFlutterBinding.ensureInitialized();
    try {
    cameras = await availableCameras();
    } on CameraException catch (e) {
    print('Error: $e.code\nError Message: $e.message');
    }
    runApp(new MyApp());
    }

    class MyApp extends StatelessWidget {
    @override
    Widget build(BuildContext context) {
    return MaterialApp(
    title: 'tflite real-time detection',
    theme: ThemeData(
    brightness: Brightness.dark,
    ),
    home: HomePage(cameras),
    );
    }
    }